diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..c1b61ec --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[run] +omit = + api/test_enhanced_api.py + +[report] +omit = + api/test_enhanced_api.py diff --git a/tests/conftest.py b/tests/conftest.py index 6b453d0..8330494 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,7 @@ def tmp_dir(): @pytest.fixture def tmp_config_dir(tmp_dir): """Temporary config dir with the sub-directories expected by managers.""" - for sub in ('api', 'caddy', 'dns', 'dhcp', 'ntp', 'wireguard'): + for sub in ('api', 'caddy', 'dns', 'ntp', 'wireguard'): os.makedirs(os.path.join(tmp_dir, sub), exist_ok=True) return tmp_dir diff --git a/tests/integration/test_live_api.py b/tests/integration/test_live_api.py index daa879a..5bbb086 100644 --- a/tests/integration/test_live_api.py +++ b/tests/integration/test_live_api.py @@ -90,7 +90,7 @@ class TestConfig: # --------------------------------------------------------------------------- EXPECTED_CONTAINERS = [ - 'cell-caddy', 'cell-dns', 'cell-dhcp', 'cell-ntp', + 'cell-caddy', 'cell-dns', 'cell-ntp', 'cell-mail', 'cell-radicale', 'cell-webdav', 'cell-wireguard', 'cell-api', 'cell-webui', 'cell-rainloop', 'cell-filegator', ] @@ -164,7 +164,7 @@ class TestWireGuard: # --------------------------------------------------------------------------- -# Network services: DNS, DHCP, NTP +# Network services: DNS, NTP # --------------------------------------------------------------------------- class TestNetworkServices: @@ -176,8 +176,8 @@ class TestNetworkServices: r = get('/api/dns/status') assert r.status_code == 200 - def test_dhcp_leases_endpoint(self): - r = get('/api/dhcp/leases') + def test_dns_overview_endpoint(self): + r = get('/api/dns/overview') assert r.status_code == 200 def test_ntp_status_endpoint(self): diff --git a/tests/integration/test_negative_scenarios.py b/tests/integration/test_negative_scenarios.py index 2150fbc..fbe3a0e 100644 --- a/tests/integration/test_negative_scenarios.py +++ b/tests/integration/test_negative_scenarios.py @@ -11,7 +11,6 @@ Endpoints covered: - /api/peers (POST, PUT, DELETE) - /api/config (PUT) - /api/dns/records (DELETE) - - /api/dhcp/reservations (POST, DELETE) - /api/containers//restart - /api/wireguard/keys/peer @@ -240,43 +239,6 @@ class TestDnsRecordsNegative: r.json() -# --------------------------------------------------------------------------- -# DHCP reservations — negative -# --------------------------------------------------------------------------- - -class TestDhcpReservationsNegative: - def test_add_reservation_no_body_returns_400(self): - r = _S.post( - f"{API_BASE}/api/dhcp/reservations", - data='', - headers={'Content-Type': 'application/json'}, - ) - assert r.status_code == 400 - - def test_add_reservation_missing_ip_returns_400(self): - r = post('/api/dhcp/reservations', json={'mac': 'aa:bb:cc:dd:ee:ff'}) - assert r.status_code == 400 - _assert_json_error(r) - - def test_add_reservation_missing_mac_returns_400(self): - r = post('/api/dhcp/reservations', json={'ip': '10.0.0.250'}) - assert r.status_code == 400 - _assert_json_error(r) - - def test_delete_reservation_no_mac_returns_400(self): - r = delete('/api/dhcp/reservations', json={'ip': '10.0.0.250'}) - assert r.status_code == 400 - _assert_json_error(r) - - def test_delete_reservation_empty_body_returns_400(self): - r = _S.delete( - f"{API_BASE}/api/dhcp/reservations", - data='', - headers={'Content-Type': 'application/json'}, - ) - assert r.status_code == 400 - - # --------------------------------------------------------------------------- # Container endpoints — negative # --------------------------------------------------------------------------- diff --git a/tests/integration/test_network_services.py b/tests/integration/test_network_services.py index 7277a6d..6c0e3b3 100644 --- a/tests/integration/test_network_services.py +++ b/tests/integration/test_network_services.py @@ -1,10 +1,8 @@ """ -Network services integration tests: DNS records, DHCP leases, DHCP reservations. +Network services integration tests: DNS records, DNS overview. Note on endpoint shapes discovered from app.py: - - DELETE /api/dns/records takes a JSON body (not a URL param) - - DELETE /api/dhcp/reservations takes JSON body with 'mac' field - - POST /api/dhcp/reservations requires 'mac' and 'ip' fields + - DELETE /api/dns/records takes a JSON body (not a URL param) Run with: pytest tests/integration/test_network_services.py -v """ @@ -129,79 +127,20 @@ class TestDnsRecordsWrite: # --------------------------------------------------------------------------- -# GET /api/dhcp/leases +# GET /api/dns/overview # --------------------------------------------------------------------------- -class TestDhcpLeases: - def test_get_dhcp_leases_returns_200(self): - r = get('/api/dhcp/leases') +class TestDnsOverview: + def test_get_dns_overview_returns_200(self): + r = get('/api/dns/overview') assert r.status_code == 200 - def test_get_dhcp_leases_returns_list_or_dict(self): - data = get('/api/dhcp/leases').json() - assert isinstance(data, (list, dict)) - - -# --------------------------------------------------------------------------- -# POST /api/dhcp/reservations + DELETE /api/dhcp/reservations -# --------------------------------------------------------------------------- - -_TEST_MAC = 'de:ad:be:ef:11:22' -_TEST_RESERVATION_IP = '10.0.0.200' - - -class TestDhcpReservations: - def _cleanup(self): - delete('/api/dhcp/reservations', json={'mac': _TEST_MAC}) - - def test_add_dhcp_reservation_returns_non_error(self): - try: - r = post('/api/dhcp/reservations', json={ - 'mac': _TEST_MAC, - 'ip': _TEST_RESERVATION_IP, - 'hostname': 'inttest-dhcp-host', - }) - assert r.status_code in (200, 201), ( - f"Expected 200/201 for DHCP reservation, got {r.status_code}: {r.text}" - ) - finally: - self._cleanup() - - def test_add_dhcp_reservation_missing_mac_returns_400(self): - r = post('/api/dhcp/reservations', json={'ip': _TEST_RESERVATION_IP}) - assert r.status_code == 400 - assert 'error' in r.json() - - def test_add_dhcp_reservation_missing_ip_returns_400(self): - r = post('/api/dhcp/reservations', json={'mac': _TEST_MAC}) - assert r.status_code == 400 - assert 'error' in r.json() - - def test_add_dhcp_reservation_empty_body_returns_400(self): - r = post('/api/dhcp/reservations', data='') - assert r.status_code == 400 - - def test_delete_dhcp_reservation_missing_mac_returns_400(self): - r = delete('/api/dhcp/reservations', json={}) - assert r.status_code == 400 - assert 'error' in r.json() - - def test_add_and_delete_dhcp_reservation_round_trip(self): - add_r = post('/api/dhcp/reservations', json={ - 'mac': _TEST_MAC, - 'ip': _TEST_RESERVATION_IP, - }) - assert add_r.status_code in (200, 201), ( - f"Could not create DHCP reservation: {add_r.text}" - ) - try: - del_r = delete('/api/dhcp/reservations', json={'mac': _TEST_MAC}) - assert del_r.status_code in (200, 204), ( - f"DHCP reservation delete failed: {del_r.status_code} {del_r.text}" - ) - except Exception: - self._cleanup() - raise + def test_get_dns_overview_has_expected_keys(self): + data = get('/api/dns/overview').json() + assert isinstance(data, dict) + for key in ('mode', 'effective_domain', 'internal_domain', + 'public_records', 'internal_records'): + assert key in data # --------------------------------------------------------------------------- diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py index 05268d3..5978960 100644 --- a/tests/test_api_endpoints.py +++ b/tests/test_api_endpoints.py @@ -160,37 +160,6 @@ class TestAPIEndpoints(unittest.TestCase): response = self.client.delete('/api/dns/records', data=json.dumps({'name': 'test'}), content_type='application/json') self.assertEqual(response.status_code, 500) - @patch('app.network_manager') - def test_dhcp_endpoints(self, mock_network): - # Mock get_dhcp_leases - mock_network.get_dhcp_leases.return_value = [{'ip': '10.0.0.2', 'mac': '00:11:22:33:44:55'}] - response = self.client.get('/api/dhcp/leases') - self.assertEqual(response.status_code, 200) - data = json.loads(response.data) - self.assertIsInstance(data, list) - # Mock add_dhcp_reservation - mock_network.add_dhcp_reservation.return_value = True - response = self.client.post('/api/dhcp/reservations', data=json.dumps({'ip': '10.0.0.2', 'mac': '00:11:22:33:44:55'}), content_type='application/json') - self.assertEqual(response.status_code, 200) - # Missing mac field → 400, not 500 - response = self.client.post('/api/dhcp/reservations', data=json.dumps({'ip': '10.0.0.2'}), content_type='application/json') - self.assertEqual(response.status_code, 400) - # Simulate manager error - mock_network.add_dhcp_reservation.side_effect = Exception('fail') - response = self.client.post('/api/dhcp/reservations', data=json.dumps({'ip': '10.0.0.2', 'mac': '00:11:22:33:44:55'}), content_type='application/json') - self.assertEqual(response.status_code, 500) - # Mock remove_dhcp_reservation - mock_network.remove_dhcp_reservation.return_value = True - response = self.client.delete('/api/dhcp/reservations', data=json.dumps({'mac': '00:11:22:33:44:55'}), content_type='application/json') - self.assertEqual(response.status_code, 200) - # Missing mac → 400 - response = self.client.delete('/api/dhcp/reservations', data=json.dumps({'ip': '10.0.0.2'}), content_type='application/json') - self.assertEqual(response.status_code, 400) - # Simulate manager error - mock_network.remove_dhcp_reservation.side_effect = Exception('fail') - response = self.client.delete('/api/dhcp/reservations', data=json.dumps({'mac': '00:11:22:33:44:55'}), content_type='application/json') - self.assertEqual(response.status_code, 500) - @patch('app.network_manager') def test_ntp_status_endpoint(self, mock_network): # Mock get_ntp_status diff --git a/tests/test_app_health_connectivity.py b/tests/test_app_health_connectivity.py new file mode 100644 index 0000000..2ec1e30 --- /dev/null +++ b/tests/test_app_health_connectivity.py @@ -0,0 +1,650 @@ +""" +Tests for app.py: health_history (deque), health monitor logic, +connectivity endpoints, caddy endpoints, egress endpoints, +and before-request hooks (enforce_setup/enforce_auth/check_csrf). +""" +import sys +from pathlib import Path +import json +from collections import deque +from unittest.mock import patch, MagicMock + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + +import app as app_module +from app import app + + +@pytest.fixture(autouse=True) +def reset_app_state(): + """Reset global mutable state between tests.""" + orig_running = app_module.health_monitor_running + orig_counters = dict(app_module.service_alert_counters) + app.config['TESTING'] = True + yield + app_module.health_monitor_running = orig_running + app_module.service_alert_counters = orig_counters + + +@pytest.fixture +def client(): + app.config['TESTING'] = True + with app.test_client() as c: + yield c + + +# --------------------------------------------------------------------------- +# health_history is a deque (not a list) +# --------------------------------------------------------------------------- + +class TestHealthHistoryIsDeque: + def test_health_history_is_deque(self): + assert isinstance(app_module.health_history, deque) + + def test_health_history_has_maxlen(self): + assert app_module.health_history.maxlen == app_module.HEALTH_HISTORY_SIZE + + def test_health_history_appendleft_works(self): + """appendleft (used in health_monitor_loop) should work on a deque.""" + hh = app_module.health_history + entry = {'timestamp': '2026-01-01T00:00:00', 'alerts': []} + hh.appendleft(entry) + assert hh[0] == entry + + def test_health_history_maxlen_evicts_old_entries(self): + hh = deque(maxlen=3) + for i in range(5): + hh.appendleft({'n': i}) + assert len(hh) == 3 + # Most recent is first + assert hh[0]['n'] == 4 + + +# --------------------------------------------------------------------------- +# GET /api/health/history +# --------------------------------------------------------------------------- + +class TestGetHealthHistory: + def test_returns_200(self, client): + with patch.object(app_module, 'health_history', deque(maxlen=100)): + resp = client.get('/api/health/history') + assert resp.status_code == 200 + + def test_returns_list(self, client): + with patch.object(app_module, 'health_history', deque(maxlen=100)): + resp = client.get('/api/health/history') + data = json.loads(resp.data) + assert isinstance(data, list) + + def test_returns_stored_entries(self, client): + hh = deque(maxlen=100) + hh.appendleft({'timestamp': 't1', 'alerts': []}) + hh.appendleft({'timestamp': 't2', 'alerts': []}) + with patch.object(app_module, 'health_history', hh): + resp = client.get('/api/health/history') + data = json.loads(resp.data) + assert len(data) == 2 + + def test_returns_empty_when_no_history(self, client): + with patch.object(app_module, 'health_history', deque(maxlen=100)): + resp = client.get('/api/health/history') + assert json.loads(resp.data) == [] + + +# --------------------------------------------------------------------------- +# POST /api/health/history/clear +# --------------------------------------------------------------------------- + +class TestClearHealthHistory: + def test_clear_returns_200(self, client): + hh = deque(maxlen=100) + hh.appendleft({'entry': 1}) + with patch.object(app_module, 'health_history', hh): + resp = client.post('/api/health/history/clear') + assert resp.status_code == 200 + + def test_clear_empties_history(self, client): + hh = deque(maxlen=100) + hh.appendleft({'entry': 1}) + with patch.object(app_module, 'health_history', hh): + client.post('/api/health/history/clear') + assert len(hh) == 0 + + def test_clear_resets_alert_counters(self, client): + app_module.service_alert_counters['network'] = 5 + hh = deque(maxlen=100) + with patch.object(app_module, 'health_history', hh): + client.post('/api/health/history/clear') + assert app_module.service_alert_counters == {} + + def test_clear_response_has_message(self, client): + hh = deque(maxlen=100) + with patch.object(app_module, 'health_history', hh): + resp = client.post('/api/health/history/clear') + data = json.loads(resp.data) + assert 'message' in data + + +# --------------------------------------------------------------------------- +# perform_health_check alerting logic +# --------------------------------------------------------------------------- + +class TestPerformHealthCheck: + def test_healthy_service_resets_counter(self): + app_module.service_alert_counters['network'] = 2 + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['network'] + network_svc = MagicMock() + network_svc.health_check.return_value = {'running': True} + mock_service_bus.get_service.return_value = network_svc + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + result = app_module.perform_health_check() + assert app_module.service_alert_counters.get('network', 0) == 0 + assert 'network' in result + + def test_unhealthy_service_with_error_key_increments_counter(self): + """Services that raise an exception get recorded with an 'error' key, + which the alerting logic recognises as unhealthy.""" + app_module.service_alert_counters = {} + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['network'] + mock_service_bus.publish_event = MagicMock() + network_svc = MagicMock() + # Raise so the result gets {'error': ..., 'status': 'offline'} + network_svc.health_check.side_effect = Exception('container down') + mock_service_bus.get_service.return_value = network_svc + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + app_module.perform_health_check() + # With an 'error' key and no 'running' key, healthy=False → counter increments + assert app_module.service_alert_counters.get('network', 0) == 1 + + def test_alert_triggered_at_threshold(self): + """Counter reaching HEALTH_ALERT_THRESHOLD emits an alert.""" + app_module.service_alert_counters = {'network': app_module.HEALTH_ALERT_THRESHOLD - 1} + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['network'] + mock_service_bus.publish_event = MagicMock() + network_svc = MagicMock() + # Use exception path to guarantee healthy=False + network_svc.health_check.side_effect = Exception('container down') + mock_service_bus.get_service.return_value = network_svc + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + result = app_module.perform_health_check() + # Alert should be in result['alerts'] + assert len(result['alerts']) >= 1 + assert any('network' in a for a in result['alerts']) + + def test_optional_store_services_skipped_when_not_installed(self): + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['email_manager'] + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = [] # email not installed + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + result = app_module.perform_health_check() + # email_manager should not appear in result (was skipped) + assert 'email_manager' not in result + + def test_optional_store_service_checked_when_installed(self): + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['email_manager'] + mock_service_bus.publish_event = MagicMock() + email_svc = MagicMock() + email_svc.health_check.return_value = {'running': True} + mock_service_bus.get_service.return_value = email_svc + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = ['email'] # email installed + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + result = app_module.perform_health_check() + assert 'email_manager' in result + + def test_service_without_health_check_falls_back_to_get_status(self): + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['routing'] + svc = MagicMock(spec=[]) # no health_check attribute + svc.get_status = MagicMock(return_value={'running': True}) + mock_service_bus.get_service.return_value = svc + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + result = app_module.perform_health_check() + assert 'routing' in result + + def test_service_exception_recorded_as_error(self): + mock_service_bus = MagicMock() + mock_service_bus.list_services.return_value = ['vault'] + svc = MagicMock() + svc.health_check.side_effect = Exception('vault down') + mock_service_bus.get_service.return_value = svc + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_service_bus), \ + patch.object(app_module, 'config_manager', mock_cfg), \ + app.app_context(): + result = app_module.perform_health_check() + assert 'error' in result.get('vault', {}) + + +# --------------------------------------------------------------------------- +# GET /api/connectivity/status +# --------------------------------------------------------------------------- + +class TestConnectivityEndpoints: + def test_connectivity_status_200(self, client): + mock_cm = MagicMock() + mock_cm.get_status.return_value = {'exits': [], 'peers': {}} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.get('/api/connectivity/status') + assert resp.status_code == 200 + + def test_connectivity_status_shape(self, client): + mock_cm = MagicMock() + mock_cm.get_status.return_value = {'exits': [], 'peers': {}} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.get('/api/connectivity/status') + data = json.loads(resp.data) + assert 'exits' in data + + def test_connectivity_status_500_on_exception(self, client): + mock_cm = MagicMock() + mock_cm.get_status.side_effect = Exception('fail') + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.get('/api/connectivity/status') + assert resp.status_code == 500 + + def test_connectivity_list_exits_200(self, client): + mock_cm = MagicMock() + mock_cm.list_exits.return_value = [] + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.get('/api/connectivity/exits') + assert resp.status_code == 200 + + def test_connectivity_list_exits_shape(self, client): + mock_cm = MagicMock() + mock_cm.list_exits.return_value = [{'type': 'wireguard_ext', 'name': 'exit1'}] + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.get('/api/connectivity/exits') + data = json.loads(resp.data) + assert 'exits' in data + assert len(data['exits']) == 1 + + def test_connectivity_upload_wireguard_missing_conf_text(self, client): + resp = client.post('/api/connectivity/exits/wireguard', + data=json.dumps({}), content_type='application/json') + assert resp.status_code == 400 + data = json.loads(resp.data) + assert 'error' in data + + def test_connectivity_upload_wireguard_empty_conf_text(self, client): + resp = client.post('/api/connectivity/exits/wireguard', + data=json.dumps({'conf_text': ' '}), + content_type='application/json') + assert resp.status_code == 400 + + def test_connectivity_upload_wireguard_success(self, client): + mock_cm = MagicMock() + mock_cm.upload_wireguard_ext.return_value = {'ok': True, 'message': 'Uploaded'} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.post('/api/connectivity/exits/wireguard', + data=json.dumps({'conf_text': '[Interface]\nPrivateKey = abc\n'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_connectivity_upload_wireguard_failure(self, client): + mock_cm = MagicMock() + mock_cm.upload_wireguard_ext.return_value = {'ok': False, 'error': 'bad config'} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.post('/api/connectivity/exits/wireguard', + data=json.dumps({'conf_text': '[Interface]\nPrivateKey = abc\n'}), + content_type='application/json') + assert resp.status_code == 400 + + def test_connectivity_upload_openvpn_missing_ovpn_text(self, client): + resp = client.post('/api/connectivity/exits/openvpn', + data=json.dumps({}), content_type='application/json') + assert resp.status_code == 400 + + def test_connectivity_upload_openvpn_success(self, client): + mock_cm = MagicMock() + mock_cm.upload_openvpn.return_value = {'ok': True} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.post('/api/connectivity/exits/openvpn', + data=json.dumps({'ovpn_text': 'client\ndev tun\n'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_connectivity_apply_routes_200(self, client): + mock_cm = MagicMock() + mock_cm.apply_routes.return_value = {'ok': True, 'applied': 0} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.post('/api/connectivity/exits/apply', + content_type='application/json') + assert resp.status_code == 200 + + def test_connectivity_set_peer_exit_missing_exit_via(self, client): + resp = client.put('/api/connectivity/peers/alice/exit', + data=json.dumps({}), content_type='application/json') + assert resp.status_code == 400 + + def test_connectivity_set_peer_exit_success(self, client): + mock_cm = MagicMock() + mock_cm.set_peer_exit.return_value = {'ok': True} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.put('/api/connectivity/peers/alice/exit', + data=json.dumps({'exit_via': 'wireguard_ext'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_connectivity_set_peer_exit_failure(self, client): + mock_cm = MagicMock() + mock_cm.set_peer_exit.return_value = {'ok': False, 'error': 'not found'} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.put('/api/connectivity/peers/alice/exit', + data=json.dumps({'exit_via': 'wireguard_ext'}), + content_type='application/json') + assert resp.status_code == 400 + + def test_connectivity_get_peer_exits_200(self, client): + mock_cm = MagicMock() + mock_cm.get_peer_exits.return_value = {'alice': 'wireguard_ext'} + with patch.object(app_module, 'connectivity_manager', mock_cm): + resp = client.get('/api/connectivity/peers') + assert resp.status_code == 200 + data = json.loads(resp.data) + assert 'peers' in data + + +# --------------------------------------------------------------------------- +# GET /api/caddy/cert-status and POST /api/caddy/cert-renew +# --------------------------------------------------------------------------- + +class TestCaddyEndpoints: + def test_caddy_cert_status_200(self, client): + mock_caddy = MagicMock() + mock_caddy.get_cert_status_fresh.return_value = {'status': 'valid', 'days_remaining': 60} + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.get('/api/caddy/cert-status') + assert resp.status_code == 200 + + def test_caddy_cert_status_shape(self, client): + mock_caddy = MagicMock() + mock_caddy.get_cert_status_fresh.return_value = {'status': 'internal'} + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.get('/api/caddy/cert-status') + data = json.loads(resp.data) + assert 'status' in data + + def test_caddy_cert_status_500_on_exception(self, client): + mock_caddy = MagicMock() + mock_caddy.get_cert_status_fresh.side_effect = Exception('Caddy unreachable') + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.get('/api/caddy/cert-status') + assert resp.status_code == 500 + + def test_caddy_cert_renew_success(self, client): + mock_caddy = MagicMock() + mock_caddy.renew_cert.return_value = {'ok': True, 'status': 'pending'} + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.post('/api/caddy/cert-renew', + content_type='application/json') + assert resp.status_code == 200 + + def test_caddy_cert_renew_failure(self, client): + mock_caddy = MagicMock() + mock_caddy.renew_cert.return_value = {'ok': False, 'error': 'LAN mode'} + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.post('/api/caddy/cert-renew', + content_type='application/json') + assert resp.status_code == 400 + + def test_caddy_cert_renew_500_on_exception(self, client): + mock_caddy = MagicMock() + mock_caddy.renew_cert.side_effect = Exception('fail') + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.post('/api/caddy/cert-renew', + content_type='application/json') + assert resp.status_code == 500 + + def test_caddy_upload_custom_cert_missing_fields(self, client): + resp = client.post('/api/caddy/custom-cert', + data=json.dumps({}), content_type='application/json') + assert resp.status_code == 400 + + def test_caddy_upload_custom_cert_success(self, client): + mock_caddy = MagicMock() + mock_caddy.upload_custom_cert.return_value = {'ok': True} + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.post('/api/caddy/custom-cert', + data=json.dumps({'cert_pem': 'CERT', 'key_pem': 'KEY'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_caddy_upload_custom_cert_failure(self, client): + mock_caddy = MagicMock() + mock_caddy.upload_custom_cert.return_value = {'ok': False, 'error': 'invalid cert'} + with patch.object(app_module, 'caddy_manager', mock_caddy): + resp = client.post('/api/caddy/custom-cert', + data=json.dumps({'cert_pem': 'BAD', 'key_pem': 'BAD'}), + content_type='application/json') + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /api/egress/status and PUT /api/egress/services//exit +# --------------------------------------------------------------------------- + +class TestEgressEndpoints: + def test_egress_status_200(self, client): + mock_egress = MagicMock() + mock_egress.get_status.return_value = {'services': {}} + with patch('app.egress_manager', mock_egress, create=True): + resp = client.get('/api/egress/status') + assert resp.status_code == 200 + + def test_egress_status_500_on_exception(self, client): + mock_egress = MagicMock() + mock_egress.get_status.side_effect = Exception('fail') + with patch('app.egress_manager', mock_egress, create=True): + resp = client.get('/api/egress/status') + assert resp.status_code == 500 + + def test_egress_set_service_exit_missing_exit_type(self, client): + mock_egress = MagicMock() + with patch('app.egress_manager', mock_egress, create=True): + resp = client.put('/api/egress/services/email/exit', + data=json.dumps({}), content_type='application/json') + assert resp.status_code == 400 + + def test_egress_set_service_exit_success(self, client): + mock_egress = MagicMock() + mock_egress.set_service_exit.return_value = {'ok': True} + with patch('app.egress_manager', mock_egress, create=True): + resp = client.put('/api/egress/services/email/exit', + data=json.dumps({'exit_type': 'wireguard_ext'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_egress_set_service_exit_failure(self, client): + mock_egress = MagicMock() + mock_egress.set_service_exit.return_value = {'ok': False, 'error': 'not found'} + with patch('app.egress_manager', mock_egress, create=True): + resp = client.put('/api/egress/services/email/exit', + data=json.dumps({'exit_type': 'wireguard_ext'}), + content_type='application/json') + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# enforce_setup hook: returns 428 when setup is not complete +# --------------------------------------------------------------------------- + +class TestEnforceSetupHook: + def test_428_when_setup_incomplete(self): + """Without TESTING=True, API requests are blocked if setup is not done.""" + app.config['TESTING'] = False + mock_setup = MagicMock() + mock_setup.is_setup_complete.return_value = False + try: + with patch.object(app_module, 'setup_manager', mock_setup): + with app.test_client() as c: + resp = c.get('/api/status') + assert resp.status_code == 428 + data = json.loads(resp.data) + assert 'redirect' in data + finally: + app.config['TESTING'] = True + + def test_setup_route_passes_when_incomplete(self): + """Setup routes always pass through regardless of setup status.""" + app.config['TESTING'] = False + mock_setup = MagicMock() + mock_setup.is_setup_complete.return_value = False + try: + with patch.object(app_module, 'setup_manager', mock_setup): + with app.test_client() as c: + resp = c.get('/api/setup/status') + # Should NOT be 428 + assert resp.status_code != 428 + finally: + app.config['TESTING'] = True + + def test_health_passes_when_incomplete(self): + """The /health endpoint always passes through.""" + app.config['TESTING'] = False + mock_setup = MagicMock() + mock_setup.is_setup_complete.return_value = False + try: + with patch.object(app_module, 'setup_manager', mock_setup): + with app.test_client() as c: + resp = c.get('/health') + assert resp.status_code == 200 + finally: + app.config['TESTING'] = True + + def test_setup_complete_passes_through(self): + """All routes pass through when setup is complete.""" + app.config['TESTING'] = False + mock_setup = MagicMock() + mock_setup.is_setup_complete.return_value = True + mock_auth = MagicMock() + mock_auth.list_users.return_value = [] + try: + with patch.object(app_module, 'setup_manager', mock_setup), \ + patch.object(app_module, 'auth_manager', mock_auth): + with app.test_client() as c: + resp = c.get('/api/status') + assert resp.status_code != 428 + finally: + app.config['TESTING'] = True + + +# --------------------------------------------------------------------------- +# enforce_auth hook: 503 when users file exists but is empty +# --------------------------------------------------------------------------- + +class TestEnforceAuthHook: + def test_503_when_users_file_empty_and_readable(self, tmp_path): + """Returns 503 when users file exists + readable but has no accounts.""" + import tempfile, os + app.config['TESTING'] = False + users_file = tmp_path / 'auth_users.json' + users_file.write_text('[]') # file exists but no accounts + + from auth_manager import AuthManager + real_auth = MagicMock(spec=AuthManager) + real_auth.list_users.return_value = [] + real_auth._users_file = str(users_file) + + mock_setup = MagicMock() + mock_setup.is_setup_complete.return_value = True + try: + with patch.object(app_module, 'auth_manager', real_auth), \ + patch.object(app_module, 'setup_manager', mock_setup): + with app.test_client() as c: + resp = c.get('/api/status') + assert resp.status_code == 503 + data = json.loads(resp.data) + assert 'error' in data + finally: + app.config['TESTING'] = True + + def test_401_when_no_session_and_users_exist(self, tmp_path): + """Returns 401 when users exist but no session cookie is set.""" + app.config['TESTING'] = False + users_file = tmp_path / 'auth_users.json' + # Users file doesn't exist — no file means enforcement + # is bypassed. Use a file that DOES have a user. + import json as _json + users_file.write_text(_json.dumps([{'username': 'admin', 'role': 'admin'}])) + + from auth_manager import AuthManager + real_auth = MagicMock(spec=AuthManager) + real_auth.list_users.return_value = [{'username': 'admin', 'role': 'admin'}] + real_auth._users_file = str(users_file) + + mock_setup = MagicMock() + mock_setup.is_setup_complete.return_value = True + try: + with patch.object(app_module, 'auth_manager', real_auth), \ + patch.object(app_module, 'setup_manager', mock_setup): + with app.test_client() as c: + # No login — no session + resp = c.get('/api/status') + assert resp.status_code == 401 + finally: + app.config['TESTING'] = True + + +# --------------------------------------------------------------------------- +# GET /api/status +# --------------------------------------------------------------------------- + +class TestGetCellStatus: + def test_returns_200(self, client): + mock_sb = MagicMock() + mock_sb.list_services.return_value = [] + mock_pr = MagicMock() + mock_pr.list_peers.return_value = [] + mock_cm = MagicMock() + mock_cm.configs = {'_identity': {'cell_name': 'test', 'domain': 'cell'}} + mock_cm.get_effective_domain.return_value = 'cell' + with patch.object(app_module, 'service_bus', mock_sb), \ + patch.object(app_module, 'peer_registry', mock_pr), \ + patch.object(app_module, 'config_manager', mock_cm): + resp = client.get('/api/status') + assert resp.status_code == 200 + + def test_status_includes_expected_keys(self, client): + mock_sb = MagicMock() + mock_sb.list_services.return_value = [] + mock_pr = MagicMock() + mock_pr.list_peers.return_value = [] + mock_cm = MagicMock() + mock_cm.configs = {'_identity': {'cell_name': 'test', 'domain': 'cell'}} + mock_cm.get_effective_domain.return_value = 'cell' + with patch.object(app_module, 'service_bus', mock_sb), \ + patch.object(app_module, 'peer_registry', mock_pr), \ + patch.object(app_module, 'config_manager', mock_cm): + resp = client.get('/api/status') + data = json.loads(resp.data) + for key in ('cell_name', 'domain', 'uptime', 'peers_count', 'services'): + assert key in data, f"Missing key: {key}" diff --git a/tests/test_calendar_manager.py b/tests/test_calendar_manager.py index e334742..c0798e6 100644 --- a/tests/test_calendar_manager.py +++ b/tests/test_calendar_manager.py @@ -1,77 +1,435 @@ -import sys -from pathlib import Path - -# Add api directory to path -api_dir = Path(__file__).parent.parent / 'api' -sys.path.insert(0, str(api_dir)) -import unittest -import tempfile -import shutil -import os -from unittest.mock import patch -from calendar_manager import CalendarManager - -class TestCalendarManager(unittest.TestCase): - def setUp(self): - self.test_dir = tempfile.mkdtemp() - self.data_dir = os.path.join(self.test_dir, 'data') - self.config_dir = os.path.join(self.test_dir, 'config') - os.makedirs(self.data_dir, exist_ok=True) - os.makedirs(self.config_dir, exist_ok=True) - self.manager = CalendarManager(data_dir=self.data_dir, config_dir=self.config_dir) - - def tearDown(self): - shutil.rmtree(self.test_dir) - - def test_initialization(self): - self.assertTrue(os.path.exists(self.manager.calendar_dir)) - self.assertTrue(os.path.exists(self.manager.radicale_dir)) - - def test_ensure_config_exists(self): - config_file = os.path.join(self.manager.radicale_dir, 'config') - if os.path.exists(config_file): - os.remove(config_file) - self.manager._ensure_config_exists() - self.assertTrue(os.path.exists(config_file)) - - def test_generate_radicale_config(self): - config_file = os.path.join(self.manager.radicale_dir, 'config') - if os.path.exists(config_file): - os.remove(config_file) - self.manager._generate_radicale_config() - self.assertTrue(os.path.exists(config_file)) - with open(config_file) as f: - content = f.read() - self.assertIn('[server]', content) - self.assertIn('hosts = 0.0.0.0:5232', content) - - def test_get_status(self): - status = self.manager.get_status() - self.assertIsInstance(status, dict) - self.assertIn('status', status) - - @patch.object(CalendarManager, 'create_calendar', return_value=True) - @patch.object(CalendarManager, 'remove_calendar', return_value=True) - def test_create_and_remove_calendar(self, mock_remove, mock_create): - result = self.manager.create_calendar('testuser', 'testcal') - self.assertTrue(result) - result = self.manager.remove_calendar('testuser', 'testcal') - self.assertTrue(result) - - @patch.object(CalendarManager, 'add_event', return_value=True) - @patch.object(CalendarManager, 'remove_event', return_value=True) - def test_add_and_remove_event(self, mock_remove, mock_add): - result = self.manager.add_event('testuser', 'testcal', {'summary': 'Test'}) - self.assertTrue(result) - result = self.manager.remove_event('testuser', 'testcal', 'dummyuid') - self.assertTrue(result) - - def test_error_handling(self): - # Force errors by passing invalid arguments, should return False - self.assertFalse(self.manager.create_calendar(None, None)) - self.assertFalse(self.manager.add_event(None, None, None)) - self.assertFalse(self.manager.remove_calendar(None, None)) - self.assertFalse(self.manager.remove_event(None, None, None)) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file +import sys +from pathlib import Path + +# Add api directory to path +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) +import unittest +import tempfile +import shutil +import os +import json +from unittest.mock import patch, MagicMock +from calendar_manager import CalendarManager + +class TestCalendarManager(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.data_dir = os.path.join(self.test_dir, 'data') + self.config_dir = os.path.join(self.test_dir, 'config') + os.makedirs(self.data_dir, exist_ok=True) + os.makedirs(self.config_dir, exist_ok=True) + self.manager = CalendarManager(data_dir=self.data_dir, config_dir=self.config_dir) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def test_initialization(self): + self.assertTrue(os.path.exists(self.manager.calendar_dir)) + self.assertTrue(os.path.exists(self.manager.radicale_dir)) + + def test_ensure_config_exists(self): + config_file = os.path.join(self.manager.radicale_dir, 'config') + if os.path.exists(config_file): + os.remove(config_file) + self.manager._ensure_config_exists() + self.assertTrue(os.path.exists(config_file)) + + def test_generate_radicale_config(self): + config_file = os.path.join(self.manager.radicale_dir, 'config') + if os.path.exists(config_file): + os.remove(config_file) + self.manager._generate_radicale_config() + self.assertTrue(os.path.exists(config_file)) + with open(config_file) as f: + content = f.read() + self.assertIn('[server]', content) + self.assertIn('hosts = 0.0.0.0:5232', content) + + def test_get_status(self): + status = self.manager.get_status() + self.assertIsInstance(status, dict) + self.assertIn('status', status) + + @patch.object(CalendarManager, 'create_calendar', return_value=True) + @patch.object(CalendarManager, 'remove_calendar', return_value=True) + def test_create_and_remove_calendar(self, mock_remove, mock_create): + result = self.manager.create_calendar('testuser', 'testcal') + self.assertTrue(result) + result = self.manager.remove_calendar('testuser', 'testcal') + self.assertTrue(result) + + @patch.object(CalendarManager, 'add_event', return_value=True) + @patch.object(CalendarManager, 'remove_event', return_value=True) + def test_add_and_remove_event(self, mock_remove, mock_add): + result = self.manager.add_event('testuser', 'testcal', {'summary': 'Test'}) + self.assertTrue(result) + result = self.manager.remove_event('testuser', 'testcal', 'dummyuid') + self.assertTrue(result) + + def test_error_handling(self): + # Force errors by passing invalid arguments, should return False + self.assertFalse(self.manager.create_calendar(None, None)) + self.assertFalse(self.manager.add_event(None, None, None)) + self.assertFalse(self.manager.remove_calendar(None, None)) + self.assertFalse(self.manager.remove_event(None, None, None)) + + # --- New tests below --- + + def test_create_calendar_user_creates_and_persists(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + result = self.manager.create_calendar_user('alice', 'password123') + self.assertTrue(result) + users = self.manager._load_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0]['username'], 'alice') + self.assertNotIn('password', users[0]) + + def test_create_calendar_user_duplicate_returns_false(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'password123') + result = self.manager.create_calendar_user('alice', 'other') + self.assertFalse(result) + users = self.manager._load_users() + self.assertEqual(len(users), 1) + + def test_create_calendar_user_creates_user_directory(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'password123') + user_dir = os.path.join(self.manager.calendar_data_dir, 'users', 'alice') + self.assertTrue(os.path.exists(user_dir)) + + def test_delete_calendar_user_removes_user(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'password123') + with patch.object(self.manager, '_sync_users_to_cell_config'): + result = self.manager.delete_calendar_user('alice') + self.assertTrue(result) + users = self.manager._load_users() + self.assertEqual(len(users), 0) + + def test_delete_calendar_user_nonexistent_returns_false(self): + result = self.manager.delete_calendar_user('nobody') + self.assertFalse(result) + + def test_delete_calendar_user_removes_directory(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'password123') + user_dir = os.path.join(self.manager.calendar_data_dir, 'users', 'alice') + self.assertTrue(os.path.exists(user_dir)) + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.delete_calendar_user('alice') + self.assertFalse(os.path.exists(user_dir)) + + def test_get_calendar_users_empty(self): + users = self.manager.get_calendar_users() + self.assertEqual(users, []) + + def test_get_calendar_users_returns_created(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'pass') + self.manager.create_calendar_user('bob', 'pass') + users = self.manager.get_calendar_users() + self.assertEqual(len(users), 2) + usernames = [u['username'] for u in users] + self.assertIn('alice', usernames) + self.assertIn('bob', usernames) + + def test_create_calendar_real_persists(self): + result = self.manager.create_calendar('alice', 'personal') + self.assertTrue(result) + calendars = self.manager._load_calendars() + self.assertEqual(len(calendars), 1) + cal = calendars[0] + self.assertEqual(cal['username'], 'alice') + self.assertEqual(cal['name'], 'personal') + + def test_create_calendar_duplicate_returns_false(self): + self.manager.create_calendar('alice', 'personal') + result = self.manager.create_calendar('alice', 'personal') + self.assertFalse(result) + + def test_create_calendar_with_description_and_color(self): + result = self.manager.create_calendar('alice', 'work', description='Work stuff', color='#ff0000') + self.assertTrue(result) + calendars = self.manager._load_calendars() + cal = calendars[0] + self.assertEqual(cal['description'], 'Work stuff') + self.assertEqual(cal['color'], '#ff0000') + + def test_create_calendar_updates_user_count(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'pass') + self.manager.create_calendar('alice', 'personal') + users = self.manager._load_users() + alice = next(u for u in users if u['username'] == 'alice') + self.assertEqual(alice['calendars_count'], 1) + + def test_remove_calendar_real_removes(self): + self.manager.create_calendar('alice', 'personal') + result = self.manager.remove_calendar('alice', 'personal') + self.assertTrue(result) + calendars = self.manager._load_calendars() + self.assertEqual(len(calendars), 0) + + def test_remove_calendar_nonexistent_returns_true(self): + """Removing a non-existent calendar is idempotent (returns True).""" + result = self.manager.remove_calendar('alice', 'nonexistent') + self.assertTrue(result) + + def test_add_event_real_persists(self): + result = self.manager.add_event('alice', 'personal', {'summary': 'Meeting'}) + self.assertTrue(result) + events = self.manager._load_events() + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['summary'], 'Meeting') + self.assertEqual(events[0]['username'], 'alice') + self.assertEqual(events[0]['calendar'], 'personal') + + def test_add_event_assigns_uid_if_missing(self): + self.manager.add_event('alice', 'personal', {'summary': 'Test'}) + events = self.manager._load_events() + self.assertIn('uid', events[0]) + + def test_add_event_preserves_existing_uid(self): + self.manager.add_event('alice', 'personal', {'summary': 'Test', 'uid': 'my-uid-123'}) + events = self.manager._load_events() + self.assertEqual(events[0]['uid'], 'my-uid-123') + + def test_remove_event_real_removes_by_uid(self): + self.manager.add_event('alice', 'personal', {'summary': 'Test', 'uid': 'uid-1'}) + result = self.manager.remove_event('alice', 'personal', 'uid-1') + self.assertTrue(result) + events = self.manager._load_events() + self.assertEqual(len(events), 0) + + def test_remove_event_does_not_remove_wrong_uid(self): + self.manager.add_event('alice', 'personal', {'summary': 'Test', 'uid': 'uid-1'}) + self.manager.add_event('alice', 'personal', {'summary': 'Other', 'uid': 'uid-2'}) + self.manager.remove_event('alice', 'personal', 'uid-1') + events = self.manager._load_events() + self.assertEqual(len(events), 1) + self.assertEqual(events[0]['uid'], 'uid-2') + + def test_create_calendar_event_persists(self): + result = self.manager.create_calendar_event( + 'alice', 'personal', 'Team meeting', + '2026-01-01T09:00:00', '2026-01-01T10:00:00', + description='Weekly sync', location='Office') + self.assertTrue(result) + events = self.manager._load_events() + self.assertEqual(len(events), 1) + ev = events[0] + self.assertEqual(ev['title'], 'Team meeting') + self.assertEqual(ev['username'], 'alice') + + def test_create_calendar_event_updates_calendar_count(self): + self.manager.create_calendar('alice', 'personal') + self.manager.create_calendar_event( + 'alice', 'personal', 'Sync', + '2026-01-01T09:00:00', '2026-01-01T10:00:00') + calendars = self.manager._load_calendars() + self.assertEqual(calendars[0]['events_count'], 1) + + def test_get_calendar_events_filters_by_user_and_calendar(self): + self.manager.create_calendar_event( + 'alice', 'personal', 'Alice event', '2026-01-01T09:00', '2026-01-01T10:00') + self.manager.create_calendar_event( + 'bob', 'personal', 'Bob event', '2026-01-01T09:00', '2026-01-01T10:00') + alice_events = self.manager.get_calendar_events('alice', 'personal') + self.assertEqual(len(alice_events), 1) + self.assertEqual(alice_events[0]['title'], 'Alice event') + + def test_get_calendar_events_date_filter(self): + self.manager.create_calendar_event( + 'alice', 'personal', 'Jan event', '2026-01-15T09:00', '2026-01-15T10:00') + self.manager.create_calendar_event( + 'alice', 'personal', 'Feb event', '2026-02-15T09:00', '2026-02-15T10:00') + filtered = self.manager.get_calendar_events( + 'alice', 'personal', start_date='2026-01-01', end_date='2026-01-31') + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0]['title'], 'Jan event') + + def test_get_calendar_status_returns_users(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'pass') + status = self.manager.get_calendar_status() + self.assertIn('users', status) + self.assertEqual(len(status['users']), 1) + self.assertEqual(status['users'][0]['username'], 'alice') + + def test_get_metrics_empty(self): + with patch.object(self.manager, '_check_calendar_status', return_value=False): + metrics = self.manager.get_metrics() + self.assertIn('users_count', metrics) + self.assertIn('calendars_count', metrics) + self.assertIn('events_count', metrics) + self.assertEqual(metrics['users_count'], 0) + + def test_get_metrics_with_data(self): + with patch.object(self.manager, '_sync_users_to_cell_config'): + self.manager.create_calendar_user('alice', 'pass') + self.manager.create_calendar('alice', 'personal') + self.manager.add_event('alice', 'personal', {'summary': 'Evt'}) + with patch.object(self.manager, '_check_calendar_status', return_value=True): + metrics = self.manager.get_metrics() + self.assertEqual(metrics['users_count'], 1) + self.assertEqual(metrics['calendars_count'], 1) + self.assertEqual(metrics['events_count'], 1) + + def test_apply_config_no_port_key(self): + result = self.manager.apply_config({}) + self.assertEqual(result['restarted'], []) + + def test_apply_config_updates_radicale_hosts(self): + # Generate config first + self.manager._generate_radicale_config() + result = self.manager.apply_config({'port': 5233}) + self.assertEqual(result['restarted'], []) + config_file = os.path.join(self.manager.radicale_dir, 'config') + with open(config_file) as f: + content = f.read() + self.assertIn('hosts = 0.0.0.0:5233', content) + + def test_apply_config_no_radicale_file_is_safe(self): + """apply_config doesn't crash if radicale config file is missing.""" + config_file = os.path.join(self.manager.radicale_dir, 'config') + if os.path.exists(config_file): + os.remove(config_file) + result = self.manager.apply_config({'port': 5234}) + # Should not raise; warnings list may or may not be empty + self.assertIn('warnings', result) + + def test_write_radicale_htpasswd_creates_entry(self): + """_write_radicale_htpasswd writes a bcrypt entry for the user.""" + htpasswd = self.manager._radicale_htpasswd_path() + os.makedirs(os.path.dirname(htpasswd), exist_ok=True) + self.manager._write_radicale_htpasswd('alice', 'mypassword') + self.assertTrue(os.path.exists(htpasswd)) + with open(htpasswd) as f: + content = f.read() + self.assertIn('alice:', content) + + def test_write_radicale_htpasswd_updates_existing_entry(self): + """_write_radicale_htpasswd replaces a user's old entry.""" + htpasswd = self.manager._radicale_htpasswd_path() + os.makedirs(os.path.dirname(htpasswd), exist_ok=True) + self.manager._write_radicale_htpasswd('alice', 'pass1') + self.manager._write_radicale_htpasswd('alice', 'pass2') + with open(htpasswd) as f: + lines = f.readlines() + alice_lines = [l for l in lines if l.startswith('alice:')] + self.assertEqual(len(alice_lines), 1) + + def test_remove_radicale_htpasswd_removes_entry(self): + htpasswd = self.manager._radicale_htpasswd_path() + os.makedirs(os.path.dirname(htpasswd), exist_ok=True) + self.manager._write_radicale_htpasswd('alice', 'pass') + self.manager._write_radicale_htpasswd('bob', 'pass') + self.manager._remove_radicale_htpasswd('alice') + with open(htpasswd) as f: + content = f.read() + self.assertNotIn('alice:', content) + self.assertIn('bob:', content) + + def test_remove_radicale_htpasswd_no_file_is_safe(self): + """_remove_radicale_htpasswd doesn't raise when the file doesn't exist.""" + htpasswd = self.manager._radicale_htpasswd_path() + if os.path.exists(htpasswd): + os.remove(htpasswd) + self.manager._remove_radicale_htpasswd('alice') # should not raise + + def test_write_radicale_htpasswd_no_config_dir_is_safe(self): + """_write_radicale_htpasswd is a no-op when the config dir doesn't exist.""" + # Don't create the config dir + self.manager._write_radicale_htpasswd('alice', 'pass') + htpasswd = self.manager._radicale_htpasswd_path() + self.assertFalse(os.path.exists(htpasswd)) + + def test_test_database_connectivity_with_accessible_dir(self): + result = self.manager._test_database_connectivity() + self.assertIn('success', result) + self.assertTrue(result['success']) + + def test_test_service_connectivity_unreachable(self): + """_test_service_connectivity returns failure when cell-radicale isn't reachable.""" + result = self.manager._test_service_connectivity() + self.assertIn('success', result) + # In test environment Radicale is not running, so should be False + self.assertFalse(result['success']) + + def test_test_web_interface_unreachable(self): + result = self.manager._test_web_interface() + self.assertIn('success', result) + self.assertFalse(result['success']) + + def test_restart_service_calls_container(self): + with patch.object(self.manager, '_restart_container', return_value=True) as mock_restart: + result = self.manager.restart_service() + self.assertTrue(result) + mock_restart.assert_called_once_with('cell-radicale') + + def test_restart_service_failure_returns_false(self): + with patch.object(self.manager, '_restart_container', return_value=False): + result = self.manager.restart_service() + self.assertFalse(result) + + def test_sync_users_to_cell_config_best_effort(self): + """_sync_users_to_cell_config failure is non-fatal.""" + with patch('config_manager.ConfigManager', side_effect=Exception('no config')): + # Should not raise + self.manager._sync_users_to_cell_config() + + def test_check_calendar_status_returns_bool(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout=':5232 LISTEN') + result = self.manager._check_calendar_status() + self.assertIsInstance(result, bool) + + def test_check_calendar_status_false_when_no_port(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout='no matching port') + result = self.manager._check_calendar_status() + self.assertFalse(result) + + def test_load_users_returns_empty_on_missing_file(self): + users = self.manager._load_users() + self.assertEqual(users, []) + + def test_load_calendars_returns_empty_on_missing_file(self): + calendars = self.manager._load_calendars() + self.assertEqual(calendars, []) + + def test_load_events_returns_empty_on_missing_file(self): + events = self.manager._load_events() + self.assertEqual(events, []) + + def test_load_users_handles_corrupt_file(self): + with open(self.manager.users_file, 'w') as f: + f.write('{corrupt') + users = self.manager._load_users() + self.assertEqual(users, []) + + def test_get_configured_port_default(self): + port = self.manager._get_configured_port() + self.assertEqual(port, 5232) + + def test_get_configured_port_from_config(self): + with patch.object(self.manager, 'get_config', return_value={'port': 5555}): + port = self.manager._get_configured_port() + self.assertEqual(port, 5555) + + def test_test_connectivity_returns_dict(self): + with patch.object(self.manager, '_test_service_connectivity', return_value={'success': False, 'message': ''}): + with patch.object(self.manager, '_test_database_connectivity', return_value={'success': True, 'message': ''}): + with patch.object(self.manager, '_test_web_interface', return_value={'success': False, 'message': ''}): + result = self.manager.test_connectivity() + self.assertIn('service_connectivity', result) + self.assertIn('database_connectivity', result) + self.assertIn('web_interface', result) + self.assertIn('success', result) + self.assertFalse(result['success']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_cell_cli_extra.py b/tests/test_cell_cli_extra.py new file mode 100644 index 0000000..038f73f --- /dev/null +++ b/tests/test_cell_cli_extra.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +""" +Additional tests for cell_cli.py covering the functions NOT in test_cli_tool.py: +- list_peers (error path) +- list_nat_rules / add_nat_rule / delete_nat_rule +- list_peer_routes / add_peer_route / delete_peer_route +- list_firewall_rules / add_firewall_rule / delete_firewall_rule +- show_services_status +- list_wireguard_peers +- show_network_info / show_dns_status / show_ntp_status +- main() command routing +""" + +import sys +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from cell_cli import ( + list_peers, add_peer, remove_peer, show_config, update_config, + list_nat_rules, add_nat_rule, delete_nat_rule, + list_peer_routes, add_peer_route, delete_peer_route, + list_firewall_rules, add_firewall_rule, delete_firewall_rule, + show_services_status, list_wireguard_peers, + show_network_info, show_dns_status, show_ntp_status, +) + + +class TestListPeersErrorPath(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_list_peers_failure_prints_error(self, mock_req, mock_print): + list_peers() + mock_print.assert_any_call('Failed to fetch peers.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=[]) + def test_list_peers_empty_list(self, mock_req, mock_print): + list_peers() + mock_print.assert_any_call('No peers configured.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=[ + {'name': 'alice', 'ip': '10.0.0.2', + 'public_key': 'abcdefghijklmnopqrstuvwxyz', 'added_at': '2026-01-01'} + ]) + def test_list_peers_shows_peer_info(self, mock_req, mock_print): + list_peers() + self.assertTrue(any('alice' in str(c) for c in mock_print.call_args_list)) + + +class TestNatRules(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'nat_rules': []}) + def test_list_nat_rules_empty(self, mock_req, mock_print): + list_nat_rules() + mock_print.assert_any_call('No NAT rules configured.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'nat_rules': [ + {'id': 1, 'source_network': '10.0.0.0/24', 'target_interface': 'eth0', + 'masquerade': True, 'nat_type': 'MASQUERADE', 'protocol': 'ALL', + 'external_port': '', 'internal_ip': '', 'internal_port': ''} + ]}) + def test_list_nat_rules_shows_rules(self, mock_req, mock_print): + list_nat_rules() + self.assertTrue(any('eth0' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_list_nat_rules_failure(self, mock_req, mock_print): + list_nat_rules() + mock_print.assert_any_call('Failed to fetch NAT rules.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'id': 1}) + def test_add_nat_rule_success(self, mock_req, mock_print): + add_nat_rule('10.0.0.0/24', 'eth0', True, 'MASQUERADE', 'ALL', '', '', '') + mock_print.assert_any_call('✅ NAT rule added.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_add_nat_rule_failure(self, mock_req, mock_print): + add_nat_rule('10.0.0.0/24', 'eth0', False, 'DNAT', 'TCP', '80', '10.0.0.5', '8080') + mock_print.assert_any_call('❌ Failed to add NAT rule.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'ok': True}) + def test_delete_nat_rule_success(self, mock_req, mock_print): + delete_nat_rule(1) + mock_print.assert_any_call('✅ NAT rule deleted.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_delete_nat_rule_failure(self, mock_req, mock_print): + delete_nat_rule(99) + mock_print.assert_any_call('❌ Failed to delete NAT rule.') + + +class TestPeerRoutes(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'peer_routes': []}) + def test_list_peer_routes_empty(self, mock_req, mock_print): + list_peer_routes() + mock_print.assert_any_call('No peer routes configured.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'peer_routes': [ + {'peer_name': 'alice', 'peer_ip': '10.0.0.2', + 'allowed_networks': ['192.168.1.0/24'], 'route_type': 'split'} + ]}) + def test_list_peer_routes_shows_routes(self, mock_req, mock_print): + list_peer_routes() + self.assertTrue(any('alice' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_list_peer_routes_failure(self, mock_req, mock_print): + list_peer_routes() + mock_print.assert_any_call('Failed to fetch peer routes.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'ok': True}) + def test_add_peer_route_success(self, mock_req, mock_print): + add_peer_route('alice', '10.0.0.2', '192.168.1.0/24', 'split') + mock_print.assert_any_call('✅ Peer route added.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_add_peer_route_failure(self, mock_req, mock_print): + add_peer_route('alice', '10.0.0.2', '192.168.1.0/24', 'split') + mock_print.assert_any_call('❌ Failed to add peer route.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'ok': True}) + def test_delete_peer_route_success(self, mock_req, mock_print): + delete_peer_route('alice') + mock_print.assert_any_call('✅ Peer route deleted.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_delete_peer_route_failure(self, mock_req, mock_print): + delete_peer_route('alice') + mock_print.assert_any_call('❌ Failed to delete peer route.') + + +class TestFirewallRules(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'firewall_rules': []}) + def test_list_firewall_rules_empty(self, mock_req, mock_print): + list_firewall_rules() + mock_print.assert_any_call('No firewall rules configured.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'firewall_rules': [ + {'id': 1, 'rule_type': 'ACCEPT', 'source': '10.0.0.0/24', + 'destination': 'any', 'protocol': 'TCP', 'port_range': '80', 'action': 'ACCEPT'} + ]}) + def test_list_firewall_rules_shows_rules(self, mock_req, mock_print): + list_firewall_rules() + self.assertTrue(any('ACCEPT' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_list_firewall_rules_failure(self, mock_req, mock_print): + list_firewall_rules() + mock_print.assert_any_call('Failed to fetch firewall rules.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'id': 1}) + def test_add_firewall_rule_success(self, mock_req, mock_print): + add_firewall_rule('ACCEPT', '10.0.0.0/24', 'any', 'ACCEPT', 'TCP', '80') + mock_print.assert_any_call('✅ Firewall rule added.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_add_firewall_rule_failure(self, mock_req, mock_print): + add_firewall_rule('DROP', 'any', 'any', 'DROP', 'ALL', '') + mock_print.assert_any_call('❌ Failed to add firewall rule.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'ok': True}) + def test_delete_firewall_rule_success(self, mock_req, mock_print): + delete_firewall_rule(1) + mock_print.assert_any_call('✅ Firewall rule deleted.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_delete_firewall_rule_failure(self, mock_req, mock_print): + delete_firewall_rule(99) + mock_print.assert_any_call('❌ Failed to delete firewall rule.') + + +class TestShowServicesStatus(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={ + 'email': {'status': 'online', 'running': True}, + 'dns': True + }) + def test_show_services_status_with_dict_and_bool(self, mock_req, mock_print): + show_services_status() + self.assertTrue(any('email' in str(c) for c in mock_print.call_args_list)) + self.assertTrue(any('dns' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_show_services_status_failure(self, mock_req, mock_print): + show_services_status() + mock_print.assert_any_call('Failed to fetch service status.') + + +class TestListWireguardPeers(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=[ + {'name': 'alice', 'public_key': 'pk1', 'ip': '10.0.0.2', 'status': 'active'} + ]) + def test_list_wireguard_peers_shows_peers(self, mock_req, mock_print): + list_wireguard_peers() + self.assertTrue(any('alice' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_list_wireguard_peers_failure(self, mock_req, mock_print): + list_wireguard_peers() + mock_print.assert_any_call('Failed to fetch WireGuard peers.') + + +class TestNetworkDnsNtpStatus(unittest.TestCase): + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'gateway': '192.168.1.1', 'subnet': '10.0.0.0/24'}) + def test_show_network_info_success(self, mock_req, mock_print): + show_network_info() + self.assertTrue(any('gateway' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_show_network_info_failure(self, mock_req, mock_print): + show_network_info() + mock_print.assert_any_call('Failed to fetch network info.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'running': True, 'port': 53}) + def test_show_dns_status_success(self, mock_req, mock_print): + show_dns_status() + self.assertTrue(any('running' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_show_dns_status_failure(self, mock_req, mock_print): + show_dns_status() + mock_print.assert_any_call('Failed to fetch DNS status.') + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value={'synced': True, 'server': 'pool.ntp.org'}) + def test_show_ntp_status_success(self, mock_req, mock_print): + show_ntp_status() + self.assertTrue(any('synced' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + @patch('cell_cli.api_request', return_value=None) + def test_show_ntp_status_failure(self, mock_req, mock_print): + show_ntp_status() + mock_print.assert_any_call('Failed to fetch NTP status.') + + +class TestMainFunction(unittest.TestCase): + """Cover main() by patching individual functions and simulating command dispatch.""" + + def _run_main(self, args): + import sys as _sys + from cell_cli import main + old_argv = _sys.argv + _sys.argv = ['cell_cli'] + args + try: + with patch('builtins.print'): + try: + main() + except SystemExit: + pass + finally: + _sys.argv = old_argv + + def test_main_status_command(self): + with patch('cell_cli.show_status') as mock_fn: + self._run_main(['status']) + mock_fn.assert_called_once() + + def test_main_peers_list_command(self): + with patch('cell_cli.list_peers') as mock_fn: + self._run_main(['peers', 'list']) + mock_fn.assert_called_once() + + def test_main_peers_add_command(self): + with patch('cell_cli.add_peer') as mock_fn: + self._run_main(['peers', 'add', 'alice', '10.0.0.2', 'pubkey']) + mock_fn.assert_called_once_with('alice', '10.0.0.2', 'pubkey') + + def test_main_peers_remove_command(self): + with patch('cell_cli.remove_peer') as mock_fn: + self._run_main(['peers', 'remove', 'alice']) + mock_fn.assert_called_once_with('alice') + + def test_main_config_show_command(self): + with patch('cell_cli.show_config') as mock_fn: + self._run_main(['config', 'show']) + mock_fn.assert_called_once() + + def test_main_config_update_command(self): + with patch('cell_cli.update_config') as mock_fn: + self._run_main(['config', 'update', 'cell_name', 'mycell']) + mock_fn.assert_called_once_with('cell_name', 'mycell') + + def test_main_routing_nat_list(self): + with patch('cell_cli.list_nat_rules') as mock_fn: + self._run_main(['routing', 'nat', 'list']) + mock_fn.assert_called_once() + + def test_main_routing_nat_add(self): + with patch('cell_cli.add_nat_rule') as mock_fn: + self._run_main(['routing', 'nat', 'add', '10.0.0.0/24', 'eth0']) + mock_fn.assert_called_once() + + def test_main_routing_nat_delete(self): + with patch('cell_cli.delete_nat_rule') as mock_fn: + self._run_main(['routing', 'nat', 'delete', '1']) + mock_fn.assert_called_once_with('1') # argparse passes as string + + def test_main_routing_peers_list(self): + with patch('cell_cli.list_peer_routes') as mock_fn: + self._run_main(['routing', 'peers', 'list']) + mock_fn.assert_called_once() + + def test_main_routing_peers_add(self): + with patch('cell_cli.add_peer_route') as mock_fn: + self._run_main(['routing', 'peers', 'add', 'alice', '10.0.0.2', + '192.168.1.0/24']) + mock_fn.assert_called_once() + + def test_main_routing_peers_delete(self): + with patch('cell_cli.delete_peer_route') as mock_fn: + self._run_main(['routing', 'peers', 'delete', 'alice']) + mock_fn.assert_called_once_with('alice') + + def test_main_routing_firewall_list(self): + with patch('cell_cli.list_firewall_rules') as mock_fn: + self._run_main(['routing', 'firewall', 'list']) + mock_fn.assert_called_once() + + def test_main_routing_firewall_add(self): + with patch('cell_cli.add_firewall_rule') as mock_fn: + self._run_main(['routing', 'firewall', 'add', + 'ACCEPT', '10.0.0.0/24', 'any', 'ACCEPT']) + mock_fn.assert_called_once() + + def test_main_routing_firewall_delete(self): + with patch('cell_cli.delete_firewall_rule') as mock_fn: + self._run_main(['routing', 'firewall', 'delete', '1']) + mock_fn.assert_called_once_with('1') + + def test_main_services_status_command(self): + with patch('cell_cli.show_services_status') as mock_fn: + self._run_main(['services-status']) + mock_fn.assert_called_once() + + def test_main_wireguard_list_command(self): + with patch('cell_cli.list_wireguard_peers') as mock_fn: + self._run_main(['wireguard-peers']) + mock_fn.assert_called_once() + + def test_main_network_info_command(self): + with patch('cell_cli.show_network_info') as mock_fn: + self._run_main(['network-info']) + mock_fn.assert_called_once() + + def test_main_dns_status_command(self): + with patch('cell_cli.show_dns_status') as mock_fn: + self._run_main(['dns-status']) + mock_fn.assert_called_once() + + def test_main_ntp_status_command(self): + with patch('cell_cli.show_ntp_status') as mock_fn: + self._run_main(['ntp-status']) + mock_fn.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_config_manager.py b/tests/test_config_manager.py index 18e45a7..b440cec 100644 --- a/tests/test_config_manager.py +++ b/tests/test_config_manager.py @@ -70,18 +70,16 @@ class TestConfigManager(unittest.TestCase): # Test valid config valid_config = { 'dns_port': 53, - 'dhcp_range': '10.0.0.100-10.0.0.200', 'ntp_servers': ['pool.ntp.org'] } validation = self.config_manager.validate_config('network', valid_config) self.assertTrue(validation['valid']) self.assertEqual(len(validation['errors']), 0) - + # Test invalid config (missing required field) invalid_config = { - 'dns_port': 53, - 'ntp_servers': ['pool.ntp.org'] - # Missing dhcp_range + 'dns_port': 53 + # Missing ntp_servers } validation = self.config_manager.validate_config('network', invalid_config) self.assertFalse(validation['valid']) @@ -387,12 +385,9 @@ class TestNetworkManagerApply(unittest.TestCase): self.data_dir = os.path.join(self.test_dir, 'data') self.config_dir = os.path.join(self.test_dir, 'config') os.makedirs(os.path.join(self.data_dir, 'dns'), exist_ok=True) - os.makedirs(os.path.join(self.config_dir, 'dhcp'), exist_ok=True) os.makedirs(os.path.join(self.config_dir, 'ntp'), exist_ok=True) # Seed minimal config files - with open(os.path.join(self.config_dir, 'dhcp', 'dnsmasq.conf'), 'w') as f: - f.write('dhcp-range=10.0.0.100,10.0.0.200,12h\ndomain=cell\n') with open(os.path.join(self.config_dir, 'ntp', 'chrony.conf'), 'w') as f: f.write('server time.google.com iburst\nserver pool.ntp.org iburst\n') @@ -403,14 +398,6 @@ class TestNetworkManagerApply(unittest.TestCase): def tearDown(self): shutil.rmtree(self.test_dir) - @patch('subprocess.run') - def test_apply_config_writes_dhcp_range(self, mock_run): - mock_run.return_value = MagicMock(returncode=0) - result = self.nm.apply_config({'dhcp_range': '192.168.1.100,192.168.1.200,24h'}) - dhcp_conf = open(os.path.join(self.config_dir, 'dhcp', 'dnsmasq.conf')).read() - self.assertIn('192.168.1.100,192.168.1.200,24h', dhcp_conf) - self.assertIn('cell-dhcp', ' '.join(result['restarted'])) - @patch('subprocess.run') def test_apply_config_writes_ntp_servers(self, mock_run): mock_run.return_value = MagicMock(returncode=0) @@ -422,14 +409,6 @@ class TestNetworkManagerApply(unittest.TestCase): self.assertNotIn('time.google.com', ntp_conf) self.assertIn('cell-ntp', result['restarted']) - @patch('subprocess.run') - def test_apply_domain_updates_dnsmasq(self, mock_run): - mock_run.return_value = MagicMock(returncode=0) - result = self.nm.apply_domain('newdomain.local') - dhcp_conf = open(os.path.join(self.config_dir, 'dhcp', 'dnsmasq.conf')).read() - self.assertIn('domain=newdomain.local', dhcp_conf) - self.assertNotIn('domain=cell', dhcp_conf) - @patch('subprocess.run') def test_apply_domain_updates_corefile(self, mock_run): """apply_domain must rewrite the Corefile zone name and reload CoreDNS.""" @@ -462,10 +441,7 @@ class TestNetworkManagerApplyCellName(unittest.TestCase): self.data_dir = os.path.join(self.test_dir, 'data') self.config_dir = os.path.join(self.test_dir, 'config') os.makedirs(os.path.join(self.data_dir, 'dns'), exist_ok=True) - os.makedirs(os.path.join(self.config_dir, 'dhcp'), exist_ok=True) os.makedirs(os.path.join(self.config_dir, 'ntp'), exist_ok=True) - with open(os.path.join(self.config_dir, 'dhcp', 'dnsmasq.conf'), 'w') as f: - f.write('domain=cell\n') with open(os.path.join(self.config_dir, 'ntp', 'chrony.conf'), 'w') as f: f.write('server pool.ntp.org iburst\n') # Create a zone file matching _generate_zone_content format (name TTL IN type value) diff --git a/tests/test_config_manager_extra.py b/tests/test_config_manager_extra.py new file mode 100644 index 0000000..9e240b7 --- /dev/null +++ b/tests/test_config_manager_extra.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +""" +Additional tests for ConfigManager covering untested utility methods: +- set_identity_field +- get_installed_services / set_installed_service / remove_installed_service +- get_connectivity_config / set_connectivity_field +- set_ddns_config / get_ddns_token / set_ddns_token +- export_config yaml format +- import_config yaml format + selective services +- backup_config exception path (lines 424-426) +- restore_config selective restore (lines 441-453) +- _validate_vol_entry (unsafe container/path/name) +- _save_all_configs OSError path +""" + +import sys +import os +import json +import tempfile +import shutil +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from config_manager import ConfigManager + + +def _make_cm(tmp): + """Create a ConfigManager with temp dirs.""" + config_file = os.path.join(tmp, 'cell_config.json') + data_dir = os.path.join(tmp, 'data') + os.makedirs(data_dir, exist_ok=True) + return ConfigManager(config_file, data_dir) + + +class TestSetIdentityField(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_set_identity_field_persists(self): + self.cm.set_identity_field('cell_name', 'mycell') + self.assertEqual(self.cm.configs['_identity']['cell_name'], 'mycell') + + def test_set_identity_field_creates_identity_if_missing(self): + self.cm.configs.pop('_identity', None) + self.cm.set_identity_field('domain', 'cell') + self.assertIn('_identity', self.cm.configs) + self.assertEqual(self.cm.configs['_identity']['domain'], 'cell') + + +class TestInstalledServices(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_get_installed_services_empty_by_default(self): + result = self.cm.get_installed_services() + self.assertIsInstance(result, dict) + + def test_set_installed_service_stores_record(self): + self.cm.set_installed_service('gitea', {'version': '1.0', 'enabled': True}) + self.assertIn('gitea', self.cm.get_installed_services()) + + def test_remove_installed_service_removes_entry(self): + self.cm.set_installed_service('gitea', {'version': '1.0'}) + self.cm.remove_installed_service('gitea') + self.assertNotIn('gitea', self.cm.get_installed_services()) + + def test_remove_installed_service_not_present_does_not_raise(self): + # Should not raise even if service was never installed + self.cm.remove_installed_service('nonexistent') + + +class TestConnectivityConfig(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_get_connectivity_config_returns_dict_with_exits(self): + result = self.cm.get_connectivity_config() + self.assertIn('exits', result) + self.assertIn('peer_exit_map', result) + + def test_get_connectivity_config_initializes_missing(self): + self.cm.configs.pop('connectivity', None) + result = self.cm.get_connectivity_config() + self.assertIsInstance(result, dict) + self.assertIn('exits', result) + + def test_set_connectivity_field_returns_true_on_success(self): + result = self.cm.set_connectivity_field('exits', {'vpn1': {'host': '10.0.0.1'}}) + self.assertTrue(result) + self.assertIn('exits', self.cm.configs.get('connectivity', {})) + + def test_set_connectivity_field_returns_false_on_save_error(self): + with patch.object(self.cm, '_save_all_configs', side_effect=OSError('disk full')): + result = self.cm.set_connectivity_field('exits', {}) + self.assertFalse(result) + + +class TestDdnsConfig(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_set_ddns_config_strips_token(self): + self.cm.set_ddns_config({'hostname': 'pic.ngo', 'token': 'SECRET'}) + ddns = self.cm.configs.get('ddns', {}) + self.assertNotIn('token', ddns) + self.assertEqual(ddns.get('hostname'), 'pic.ngo') + + def test_set_ddns_token_writes_to_file(self): + self.cm.set_ddns_token('mytoken123') + token_path = self.cm._ddns_token_path + self.assertTrue(token_path.exists()) + self.assertEqual(token_path.read_text().strip(), 'mytoken123') + + def test_get_ddns_token_reads_from_file(self): + self.cm.set_ddns_token('readmetoken') + result = self.cm.get_ddns_token() + self.assertEqual(result, 'readmetoken') + + def test_get_ddns_token_migrates_from_configs(self): + # Legacy token stored in cell_config.json + self.cm.configs['ddns'] = {'hostname': 'pic.ngo', 'token': 'oldtoken'} + result = self.cm.get_ddns_token() + self.assertEqual(result, 'oldtoken') + # After migration, should be in file + self.assertTrue(self.cm._ddns_token_path.exists()) + + def test_set_ddns_token_oserror_does_not_raise(self): + with patch('builtins.open', side_effect=OSError('no space')): + with patch.object(Path, 'parent', new_callable=lambda: property(lambda self: Path(self.name).parent)): + # Just make sure no exception propagates + try: + self.cm.set_ddns_token('tok') + except Exception: + pass + + def test_set_ddns_token_removes_legacy_token_from_config(self): + self.cm.configs['ddns'] = {'hostname': 'pic.ngo', 'token': 'legacytok'} + self.cm.set_ddns_token('newtok') + # Legacy token should be removed from in-memory config + ddns = self.cm.configs.get('ddns', {}) + self.assertNotIn('token', ddns) + + +class TestExportImportConfigExtra(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_export_config_yaml_format(self): + self.cm.update_service_config('network', {'dns_port': 53}) + result = self.cm.export_config('yaml') + self.assertIn('network', result) + + def test_export_config_filters_by_services(self): + self.cm.update_service_config('network', {'dns_port': 53}) + self.cm.update_service_config('wireguard', {'port': 51820}) + result = self.cm.export_config('json', services=['network']) + data = json.loads(result) + self.assertIn('network', data) + self.assertNotIn('wireguard', data) + + def test_import_config_yaml_format(self): + yaml_data = 'network:\n dns_port: 53\n' + result = self.cm.import_config(yaml_data, 'yaml') + self.assertTrue(result) + + def test_import_config_filters_by_services(self): + data = json.dumps({'network': {'dns_port': 53}, 'wireguard': {'port': 51820}}) + result = self.cm.import_config(data, 'json', services=['network']) + self.assertTrue(result) + self.assertEqual(self.cm.configs.get('network', {}).get('dns_port'), 53) + + def test_import_config_unsupported_format_returns_false(self): + result = self.cm.import_config('', 'xml') + self.assertFalse(result) + + def test_import_config_with_identity(self): + data = json.dumps({'identity': {'cell_name': 'imported_cell'}}) + result = self.cm.import_config(data, 'json') + self.assertTrue(result) + self.assertEqual( + self.cm.configs.get('_identity', {}).get('cell_name'), + 'imported_cell' + ) + + +class TestBackupRestoreExtra(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_backup_config_exception_reraises(self): + # Force an exception by making shutil.copy2 raise after backup_path is created + with patch('shutil.copy2', side_effect=OSError('disk full')): + # backup_config reraises on exception + with self.assertRaises(Exception): + self.cm.backup_config() + + def test_restore_config_selective_services(self): + # Create a real backup first + backup_id = self.cm.backup_config() + # Change a config value then restore selectively + self.cm.configs.setdefault('network', {})['dns_port'] = 9999 + result = self.cm.restore_config(backup_id, services=['network']) + self.assertTrue(result) + + def test_restore_config_nonexistent_backup_returns_false(self): + result = self.cm.restore_config('backup_nonexistent_999') + self.assertFalse(result) + + +class TestSaveAllConfigsError(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_save_all_configs_permission_error_is_logged(self): + # Replace the config_file path with something that will fail to write + with patch('builtins.open', side_effect=PermissionError('no permission')): + # Should not raise + self.cm._save_all_configs() + + +class TestValidateVolEntry(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.cm = _make_cm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_valid_vol_entry_returns_true(self): + result = self.cm._validate_vol_entry('email', { + 'container': 'cell-mail', + 'path': '/data/mail', + 'name': 'mail_data' + }) + self.assertTrue(result) + + def test_unsafe_container_name_returns_false(self): + result = self.cm._validate_vol_entry('email', { + 'container': '../../../etc/passwd', + 'path': '/data', + 'name': 'safe_name' + }) + self.assertFalse(result) + + def test_unsafe_path_traversal_returns_false(self): + result = self.cm._validate_vol_entry('email', { + 'container': 'cell-mail', + 'path': '/data/../etc', + 'name': 'safe_name' + }) + self.assertFalse(result) + + def test_path_not_starting_with_slash_returns_false(self): + result = self.cm._validate_vol_entry('email', { + 'container': 'cell-mail', + 'path': 'relative/path', + 'name': 'safe_name' + }) + self.assertFalse(result) + + def test_unsafe_vol_name_returns_false(self): + result = self.cm._validate_vol_entry('email', { + 'container': 'cell-mail', + 'path': '/data/mail', + 'name': 'name with spaces!' + }) + self.assertFalse(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_connectivity_manager.py b/tests/test_connectivity_manager.py index 457a13c..a3fb21f 100644 --- a/tests/test_connectivity_manager.py +++ b/tests/test_connectivity_manager.py @@ -744,5 +744,124 @@ class TestApplyRoutes(unittest.TestCase): self.assertIsInstance(result['rules_applied'], int) +# --------------------------------------------------------------------------- +# _exit_status — status string + store-service bridge +# --------------------------------------------------------------------------- + +class TestExitStatus(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _mgr(self, installed=None): + config_manager = MagicMock() + config_manager.get_installed_services.return_value = installed or {} + return _make_manager(tmp_dir=self.tmp, config_manager=config_manager) + + def test_status_not_configured_when_nothing_present(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('wireguard_ext') + self.assertEqual(info['status'], 'not_configured') + self.assertFalse(info['configured']) + self.assertFalse(info['iface_up']) + + def test_status_configured_when_legacy_file_present(self): + mgr = self._mgr() + path = os.path.join(mgr.wireguard_ext_dir, 'wg_ext0.conf') + with open(path, 'w') as f: + f.write('[Interface]\nPrivateKey = abc\n') + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('wireguard_ext') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + def test_status_active_when_iface_up(self): + mgr = self._mgr() + path = os.path.join(mgr.wireguard_ext_dir, 'wg_ext0.conf') + with open(path, 'w') as f: + f.write('[Interface]\nPrivateKey = abc\n') + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock( + returncode=0, stdout='4: wg_ext0: ', stderr='' + ) + info = mgr._exit_status('wireguard_ext') + self.assertTrue(info['iface_up']) + self.assertEqual(info['status'], 'active') + + def test_store_installed_wireguard_ext_reports_configured(self): + mgr = self._mgr(installed={'wireguard-ext': {'manifest': {'id': 'wireguard-ext'}}}) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('wireguard_ext') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + def test_store_installed_openvpn_client_reports_configured(self): + mgr = self._mgr(installed={'openvpn-client': {'manifest': {'id': 'openvpn-client'}}}) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('openvpn') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + def test_unrelated_store_service_does_not_configure_exit(self): + mgr = self._mgr(installed={'email': {'manifest': {'id': 'email'}}}) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('wireguard_ext') + self.assertFalse(info['configured']) + self.assertEqual(info['status'], 'not_configured') + + def test_running_container_reports_configured(self): + mgr = self._mgr() + + def fake_run(cmd, **kwargs): + if 'inspect' in cmd: + return MagicMock(returncode=0, stdout='true\n', stderr='') + return MagicMock(returncode=1, stdout='', stderr='') + + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.side_effect = fake_run + info = mgr._exit_status('wireguard_ext') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + def test_stopped_container_does_not_configure_exit(self): + mgr = self._mgr() + + def fake_run(cmd, **kwargs): + if 'inspect' in cmd: + return MagicMock(returncode=0, stdout='false\n', stderr='') + return MagicMock(returncode=1, stdout='', stderr='') + + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.side_effect = fake_run + info = mgr._exit_status('wireguard_ext') + self.assertFalse(info['configured']) + + def test_list_exits_entries_have_status_string(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + exits = mgr.list_exits() + for item in exits: + self.assertIn('status', item) + self.assertIn(item['status'], ('active', 'configured', 'not_configured')) + + def test_tor_defaults_to_configured(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('tor') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_connectivity_proxy.py b/tests/test_connectivity_proxy.py new file mode 100644 index 0000000..416f22a --- /dev/null +++ b/tests/test_connectivity_proxy.py @@ -0,0 +1,519 @@ +""" +Tests for the proxy (redsocks) exit type — configure_proxy validation, +redsocks.conf generation (golden strings, no injection), apply_routes +REDIRECT rules, _exit_status bridging, egress_manager mirroring, and the +/api/connectivity/exits/proxy route (never echoes secrets). +""" + +import os +import stat +import sys +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'api')) + +import connectivity_manager as cm_module +from connectivity_manager import ConnectivityManager +import egress_manager as em_module + +_SENTINEL = object() + + +def _make_manager(tmp_dir=None, peer_registry=_SENTINEL, config_manager=None): + if tmp_dir is None: + tmp_dir = tempfile.mkdtemp() + + if config_manager is None: + config_manager = MagicMock() + config_manager.get_identity.return_value = { + 'cell_name': 'test', + 'ip_range': '172.20.0.0/16', + } + config_manager.get_connectivity_config.return_value = { + 'exits': {}, 'peer_exit_map': {}, + } + config_manager.get_installed_services.return_value = {} + + if peer_registry is _SENTINEL: + peer_registry = MagicMock() + peer_registry.list_peers.return_value = [] + + with patch.object(ConnectivityManager, '_subscribe_to_events', lambda self: None): + mgr = ConnectivityManager( + config_manager=config_manager, + peer_registry=peer_registry, + data_dir=tmp_dir, + config_dir=tmp_dir, + ) + return mgr + + +def _valid_cfg(**overrides): + cfg = { + 'scheme': 'socks5', + 'host': 'proxy.example.com', + 'port': 1080, + } + cfg.update(overrides) + return cfg + + +def _mock_subprocess_ok(): + return MagicMock(returncode=0, stdout='', stderr='') + + +# --------------------------------------------------------------------------- +# configure_proxy — validation +# --------------------------------------------------------------------------- + +class TestConfigureProxyValidation(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.mgr = _make_manager(tmp_dir=self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_valid_socks5_config_returns_ok(self): + result = self.mgr.configure_proxy(_valid_cfg()) + self.assertTrue(result['ok'], result) + + def test_valid_http_config_returns_ok(self): + result = self.mgr.configure_proxy(_valid_cfg(scheme='http', port=3128)) + self.assertTrue(result['ok']) + + def test_non_dict_config_rejected(self): + result = self.mgr.configure_proxy([1, 2, 3]) + self.assertFalse(result['ok']) + + def test_missing_scheme_rejected(self): + cfg = _valid_cfg() + del cfg['scheme'] + result = self.mgr.configure_proxy(cfg) + self.assertFalse(result['ok']) + self.assertIn('scheme', result['error']) + + def test_invalid_scheme_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(scheme='socks4')) + self.assertFalse(result['ok']) + + def test_https_scheme_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(scheme='https')) + self.assertFalse(result['ok']) + + def test_missing_host_rejected(self): + cfg = _valid_cfg() + del cfg['host'] + result = self.mgr.configure_proxy(cfg) + self.assertFalse(result['ok']) + self.assertIn('host', result['error']) + + def test_host_with_semicolon_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(host='evil;injected')) + self.assertFalse(result['ok']) + + def test_host_with_quote_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(host='a"b')) + self.assertFalse(result['ok']) + + def test_host_with_newline_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(host='a\nb')) + self.assertFalse(result['ok']) + + def test_ip_host_accepted(self): + result = self.mgr.configure_proxy(_valid_cfg(host='203.0.113.99')) + self.assertTrue(result['ok']) + + def test_missing_port_rejected(self): + cfg = _valid_cfg() + del cfg['port'] + result = self.mgr.configure_proxy(cfg) + self.assertFalse(result['ok']) + self.assertIn('port', result['error']) + + def test_port_zero_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(port=0)) + self.assertFalse(result['ok']) + + def test_port_above_65535_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(port=65536)) + self.assertFalse(result['ok']) + + def test_port_non_numeric_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(port='oops')) + self.assertFalse(result['ok']) + + def test_user_with_injection_chars_rejected(self): + result = self.mgr.configure_proxy( + _valid_cfg(user='user";\nip = evil')) + self.assertFalse(result['ok']) + + def test_password_with_double_quote_rejected(self): + result = self.mgr.configure_proxy( + _valid_cfg(user='bob', password='pa"ss')) + self.assertFalse(result['ok']) + + def test_password_with_backslash_rejected(self): + result = self.mgr.configure_proxy( + _valid_cfg(user='bob', password='pa\\ss')) + self.assertFalse(result['ok']) + + def test_password_with_newline_rejected(self): + result = self.mgr.configure_proxy( + _valid_cfg(user='bob', password='pa\nss')) + self.assertFalse(result['ok']) + + def test_password_without_user_rejected(self): + result = self.mgr.configure_proxy(_valid_cfg(password='secret')) + self.assertFalse(result['ok']) + + def test_result_never_contains_password(self): + result = self.mgr.configure_proxy( + _valid_cfg(user='bob', password='topsecret99')) + self.assertNotIn('topsecret99', str(result)) + + +# --------------------------------------------------------------------------- +# configure_proxy — redsocks.conf generation +# --------------------------------------------------------------------------- + +class TestRedsocksConfGeneration(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.mgr = _make_manager(tmp_dir=self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _conf_path(self): + return Path(self.mgr.proxy_dir, 'redsocks.conf') + + def test_socks5_conf_golden(self): + self.mgr.configure_proxy(_valid_cfg()) + expected = ( + 'base {\n' + ' log_debug = off;\n' + ' log_info = on;\n' + ' log = stderr;\n' + ' daemon = off;\n' + ' redirector = iptables;\n' + '}\n' + '\n' + 'redsocks {\n' + ' local_ip = 0.0.0.0;\n' + ' local_port = 12345;\n' + ' ip = proxy.example.com;\n' + ' port = 1080;\n' + ' type = socks5;\n' + '}\n' + ) + self.assertEqual(self._conf_path().read_text(), expected) + + def test_http_conf_uses_http_connect_type(self): + self.mgr.configure_proxy(_valid_cfg(scheme='http', port=3128)) + conf = self._conf_path().read_text() + self.assertIn('type = http-connect;', conf) + self.assertIn('port = 3128;', conf) + + def test_auth_conf_golden_with_login_and_password(self): + self.mgr.configure_proxy(_valid_cfg(user='bob', password='s3cret!')) + conf = self._conf_path().read_text() + self.assertIn(' login = "bob";\n', conf) + self.assertIn(' password = "s3cret!";\n', conf) + + def test_conf_without_auth_has_no_login_lines(self): + self.mgr.configure_proxy(_valid_cfg()) + conf = self._conf_path().read_text() + self.assertNotIn('login', conf) + self.assertNotIn('password', conf) + + def test_conf_file_mode_0600(self): + self.mgr.configure_proxy(_valid_cfg(user='bob', password='s3cret!')) + mode = stat.S_IMODE(os.stat(self._conf_path()).st_mode) + self.assertEqual(mode, 0o600) + + def test_password_not_persisted_in_config_manager(self): + self.mgr.configure_proxy(_valid_cfg(user='bob', password='s3cret!')) + self.mgr.config_manager.set_connectivity_field.assert_called_once() + field, exits = self.mgr.config_manager.set_connectivity_field.call_args[0] + self.assertEqual(field, 'exits') + self.assertEqual(exits['proxy']['scheme'], 'socks5') + self.assertEqual(exits['proxy']['user'], 'bob') + self.assertNotIn('password', exits['proxy']) + + def test_write_failure_returns_ok_false(self): + with patch.object(self.mgr, '_write_secure', side_effect=OSError('disk full')): + result = self.mgr.configure_proxy(_valid_cfg()) + self.assertFalse(result['ok']) + + +# --------------------------------------------------------------------------- +# apply_routes — proxy REDIRECT +# --------------------------------------------------------------------------- + +class TestApplyRoutesProxy(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_proxy_peer_gets_redirect_to_12345(self): + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'bob', 'exit_via': 'proxy'}, + ] + pr.get_peer.return_value = {'peer': 'bob', 'ip': '172.20.0.60/32'} + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + redirect_calls = [ + c for c in mock_sp.run.call_args_list + if 'REDIRECT' in c.args[0] + ] + self.assertEqual(len(redirect_calls), 1) + args = redirect_calls[0].args[0] + self.assertEqual(args[args.index('--to-ports') + 1], '12345') + self.assertIn('172.20.0.60', args) + + def test_proxy_peer_gets_mark_0x50(self): + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'bob', 'exit_via': 'proxy'}, + ] + pr.get_peer.return_value = {'peer': 'bob', 'ip': '172.20.0.60/32'} + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + mark_calls = [ + c for c in mock_sp.run.call_args_list + if 'MARK' in c.args[0] and '172.20.0.60' in c.args[0] + ] + self.assertEqual(len(mark_calls), 1) + args = mark_calls[0].args[0] + self.assertEqual(args[args.index('--set-mark') + 1], hex(0x50)) + + def test_ip_rule_added_for_proxy_table_150(self): + mgr = _make_manager(tmp_dir=self.tmp) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + mgr.apply_routes() + rule_adds = [ + c for c in mock_sp.run.call_args_list + if 'rule' in c.args[0] and 'add' in c.args[0] + and hex(0x50) in c.args[0] + ] + self.assertEqual(len(rule_adds), 1) + self.assertIn('150', rule_adds[0].args[0]) + + def test_tor_redirect_still_uses_9040(self): + """Regression: tor redirect must be unaffected by the new exits.""" + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'carol', 'exit_via': 'tor'}, + ] + pr.get_peer.return_value = {'peer': 'carol', 'ip': '172.20.0.70/32'} + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + redirect_calls = [ + c for c in mock_sp.run.call_args_list + if 'REDIRECT' in c.args[0] + ] + self.assertEqual(len(redirect_calls), 1) + args = redirect_calls[0].args[0] + self.assertEqual(args[args.index('--to-ports') + 1], '9040') + + +# --------------------------------------------------------------------------- +# egress_manager mirror — marks/tables/redirect ports +# --------------------------------------------------------------------------- + +class TestEgressManagerMirror(unittest.TestCase): + + def test_exit_types_include_sshuttle_and_proxy(self): + self.assertIn('sshuttle', em_module.EXIT_TYPES) + self.assertIn('proxy', em_module.EXIT_TYPES) + + def test_marks_do_not_collide_with_connectivity(self): + self.assertEqual(em_module.MARKS['sshuttle'], 0x140) + self.assertEqual(em_module.MARKS['proxy'], 0x150) + self.assertNotIn(em_module.MARKS['sshuttle'], + ConnectivityManager.MARKS.values()) + self.assertNotIn(em_module.MARKS['proxy'], + ConnectivityManager.MARKS.values()) + + def test_tables(self): + self.assertEqual(em_module.TABLES['sshuttle'], 240) + self.assertEqual(em_module.TABLES['proxy'], 250) + + def _make_egress(self, exit_via): + config_manager = MagicMock() + manifest = { + 'id': 'svc', + 'container_name': 'cell-svc', + 'has_egress': True, + 'egress': {'default': exit_via, 'allowed': list(em_module.EXIT_TYPES)}, + } + config_manager.get_installed_services.return_value = { + 'svc': {'manifest': manifest}, + } + config_manager.configs = {'egress_overrides': {}} + return em_module.EgressManager(config_manager=config_manager) + + def test_apply_service_sshuttle_redirects_to_12300(self): + em = self._make_egress('sshuttle') + with patch.object(em_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock( + returncode=0, stdout='172.21.0.5', stderr='') + result = em.apply_service('svc') + self.assertTrue(result['ok'], result) + redirect_calls = [ + c for c in mock_sp.run.call_args_list + if 'REDIRECT' in c.args[0] + ] + self.assertEqual(len(redirect_calls), 1) + args = redirect_calls[0].args[0] + self.assertEqual(args[args.index('--to-ports') + 1], '12300') + + def test_apply_service_proxy_redirects_to_12345(self): + em = self._make_egress('proxy') + with patch.object(em_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock( + returncode=0, stdout='172.21.0.5', stderr='') + result = em.apply_service('svc') + self.assertTrue(result['ok'], result) + redirect_calls = [ + c for c in mock_sp.run.call_args_list + if 'REDIRECT' in c.args[0] + ] + self.assertEqual(len(redirect_calls), 1) + args = redirect_calls[0].args[0] + self.assertEqual(args[args.index('--to-ports') + 1], '12345') + + +# --------------------------------------------------------------------------- +# _exit_status — proxy bridge +# --------------------------------------------------------------------------- + +class TestProxyExitStatus(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _mgr(self, installed=None): + config_manager = MagicMock() + config_manager.get_identity.return_value = {'ip_range': '172.20.0.0/16'} + config_manager.get_installed_services.return_value = installed or {} + return _make_manager(tmp_dir=self.tmp, config_manager=config_manager) + + def test_not_configured_initially(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('proxy') + self.assertFalse(info['configured']) + self.assertEqual(info['status'], 'not_configured') + + def test_configured_after_configure_proxy(self): + mgr = self._mgr() + mgr.configure_proxy(_valid_cfg()) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('proxy') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + def test_configured_when_store_service_installed(self): + mgr = self._mgr(installed={'proxy': {'manifest': {'id': 'proxy'}}}) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('proxy') + self.assertTrue(info['configured']) + + def test_configured_when_redsocks_container_running(self): + mgr = self._mgr() + + def fake_run(cmd, **kwargs): + if 'inspect' in cmd and 'cell-redsocks' in cmd: + return MagicMock(returncode=0, stdout='true\n', stderr='') + return MagicMock(returncode=1, stdout='', stderr='') + + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.side_effect = fake_run + info = mgr._exit_status('proxy') + self.assertTrue(info['configured']) + + def test_proxy_in_list_exits(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + exits = mgr.list_exits() + types = {e['type'] for e in exits} + self.assertIn('proxy', types) + + +# --------------------------------------------------------------------------- +# POST /api/connectivity/exits/proxy — route behaviour +# --------------------------------------------------------------------------- + +class TestProxyRoute(unittest.TestCase): + + def setUp(self): + import app as app_module + self.app_module = app_module + app_module.app.config['TESTING'] = True + self.client = app_module.app.test_client() + + def test_valid_config_returns_200_ok_only(self): + mock_cm = MagicMock() + mock_cm.configure_proxy.return_value = {'ok': True} + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/proxy', + json=_valid_cfg(user='bob', password='pw123')) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.get_json(), {'ok': True}) + + def test_invalid_config_returns_400(self): + mock_cm = MagicMock() + mock_cm.configure_proxy.return_value = {'ok': False, 'error': 'invalid scheme'} + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/proxy', + json={'scheme': 'gopher'}) + self.assertEqual(resp.status_code, 400) + self.assertFalse(resp.get_json()['ok']) + + def test_response_never_echoes_password(self): + mock_cm = MagicMock() + mock_cm.configure_proxy.return_value = {'ok': True} + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post( + '/api/connectivity/exits/proxy', + json=_valid_cfg(user='bob', password='ultra-secret-pw')) + self.assertNotIn('ultra-secret-pw', resp.get_data(as_text=True)) + + def test_exception_returns_500_without_details(self): + mock_cm = MagicMock() + mock_cm.configure_proxy.side_effect = Exception('boom secret-detail') + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/proxy', + json=_valid_cfg()) + self.assertEqual(resp.status_code, 500) + self.assertNotIn('secret-detail', resp.get_data(as_text=True)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_connectivity_sshuttle.py b/tests/test_connectivity_sshuttle.py new file mode 100644 index 0000000..ba2e495 --- /dev/null +++ b/tests/test_connectivity_sshuttle.py @@ -0,0 +1,559 @@ +""" +Tests for the sshuttle (SSH tunnel) exit type — configure_sshuttle validation, +config-file generation, vault integration, apply_routes REDIRECT rules, +_exit_status bridging, and the /api/connectivity/exits/sshuttle route +(never echoes secrets). +""" + +import os +import stat +import sys +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'api')) + +import connectivity_manager as cm_module +from connectivity_manager import ConnectivityManager + +_SENTINEL = object() + +VALID_KEY = ( + '-----BEGIN OPENSSH PRIVATE KEY-----\n' + 'b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAAB\n' + '-----END OPENSSH PRIVATE KEY-----\n' +) +VALID_KNOWN_HOSTS = ( + 'ssh.example.com,203.0.113.5 ssh-ed25519 ' + 'AAAAC3NzaC1lZDI1NTE5AAAAIB5d0o0Yw1xP1Yw1xP1Yw1xP1Yw1xP1Yw1xP1Yw1xP1Y' +) + + +def _make_manager(tmp_dir=None, peer_registry=_SENTINEL, config_manager=None, + vault_manager=None): + if tmp_dir is None: + tmp_dir = tempfile.mkdtemp() + + if config_manager is None: + config_manager = MagicMock() + config_manager.get_identity.return_value = { + 'cell_name': 'test', + 'ip_range': '172.20.0.0/16', + } + config_manager.get_connectivity_config.return_value = { + 'exits': {}, 'peer_exit_map': {}, + } + config_manager.get_installed_services.return_value = {} + + if peer_registry is _SENTINEL: + peer_registry = MagicMock() + peer_registry.list_peers.return_value = [] + + with patch.object(ConnectivityManager, '_subscribe_to_events', lambda self: None): + mgr = ConnectivityManager( + config_manager=config_manager, + peer_registry=peer_registry, + vault_manager=vault_manager, + data_dir=tmp_dir, + config_dir=tmp_dir, + ) + return mgr + + +def _valid_cfg(**overrides): + cfg = { + 'host': 'ssh.example.com', + 'port': 22, + 'user': 'tunnel', + 'auth': 'key', + 'private_key': VALID_KEY, + 'known_hosts': VALID_KNOWN_HOSTS, + } + cfg.update(overrides) + return cfg + + +def _mock_subprocess_ok(): + return MagicMock(returncode=0, stdout='', stderr='') + + +# --------------------------------------------------------------------------- +# configure_sshuttle — validation +# --------------------------------------------------------------------------- + +class TestConfigureSshuttleValidation(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.mgr = _make_manager(tmp_dir=self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_valid_config_returns_ok(self): + result = self.mgr.configure_sshuttle(_valid_cfg()) + self.assertTrue(result['ok'], result) + + def test_non_dict_config_rejected(self): + result = self.mgr.configure_sshuttle('not a dict') + self.assertFalse(result['ok']) + + def test_missing_host_rejected(self): + cfg = _valid_cfg() + del cfg['host'] + result = self.mgr.configure_sshuttle(cfg) + self.assertFalse(result['ok']) + self.assertIn('host', result['error']) + + def test_host_with_spaces_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(host='bad host')) + self.assertFalse(result['ok']) + + def test_host_with_shell_metachars_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(host='host;rm -rf /')) + self.assertFalse(result['ok']) + + def test_host_with_double_dots_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(host='a..b')) + self.assertFalse(result['ok']) + + def test_ip_host_accepted(self): + result = self.mgr.configure_sshuttle(_valid_cfg(host='203.0.113.10')) + self.assertTrue(result['ok']) + + def test_port_zero_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(port=0)) + self.assertFalse(result['ok']) + self.assertIn('port', result['error']) + + def test_port_above_65535_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(port=70000)) + self.assertFalse(result['ok']) + + def test_port_non_numeric_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(port='abc')) + self.assertFalse(result['ok']) + + def test_port_defaults_to_22(self): + cfg = _valid_cfg() + del cfg['port'] + result = self.mgr.configure_sshuttle(cfg) + self.assertTrue(result['ok']) + conf = Path(self.mgr.sshuttle_dir, 'sshuttle.conf').read_text() + self.assertIn('PORT=22', conf) + + def test_invalid_user_uppercase_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(user='Tunnel')) + self.assertFalse(result['ok']) + self.assertIn('user', result['error']) + + def test_invalid_user_too_long_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(user='a' * 33)) + self.assertFalse(result['ok']) + + def test_user_starting_with_digit_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(user='1user')) + self.assertFalse(result['ok']) + + def test_invalid_auth_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(auth='agent')) + self.assertFalse(result['ok']) + self.assertIn('auth', result['error']) + + def test_missing_known_hosts_rejected(self): + cfg = _valid_cfg() + del cfg['known_hosts'] + result = self.mgr.configure_sshuttle(cfg) + self.assertFalse(result['ok']) + self.assertIn('known_hosts', result['error']) + + def test_empty_known_hosts_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(known_hosts=' ')) + self.assertFalse(result['ok']) + + def test_known_hosts_with_too_few_fields_rejected(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(known_hosts='ssh.example.com ssh-ed25519')) + self.assertFalse(result['ok']) + + def test_known_hosts_with_bad_keytype_rejected(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(known_hosts='ssh.example.com ssh-dss AAAAB3Nza')) + self.assertFalse(result['ok']) + + def test_known_hosts_with_non_base64_key_rejected(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(known_hosts='ssh.example.com ssh-ed25519 not$base64!')) + self.assertFalse(result['ok']) + + def test_multiline_known_hosts_rejected(self): + kh = VALID_KNOWN_HOSTS + '\nother.example.com ssh-ed25519 AAAAC3Nza' + result = self.mgr.configure_sshuttle(_valid_cfg(known_hosts=kh)) + self.assertFalse(result['ok']) + + def test_strict_host_key_checking_no_rejected(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(host='ssh.example.com -oStrictHostKeyChecking=no')) + self.assertFalse(result['ok']) + self.assertIn('StrictHostKeyChecking', result['error']) + + def test_strict_host_key_checking_no_in_known_hosts_rejected(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(known_hosts='x StrictHostKeyChecking=no y')) + self.assertFalse(result['ok']) + + def test_strict_host_key_checking_no_case_insensitive(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(host='h stricthostkeychecking = NO')) + self.assertFalse(result['ok']) + + def test_key_auth_without_private_key_rejected(self): + cfg = _valid_cfg() + del cfg['private_key'] + result = self.mgr.configure_sshuttle(cfg) + self.assertFalse(result['ok']) + self.assertIn('private_key', result['error']) + + def test_key_auth_with_garbage_key_rejected(self): + result = self.mgr.configure_sshuttle(_valid_cfg(private_key='not a key')) + self.assertFalse(result['ok']) + + def test_password_auth_without_password_rejected(self): + cfg = _valid_cfg(auth='password') + del cfg['private_key'] + result = self.mgr.configure_sshuttle(cfg) + self.assertFalse(result['ok']) + + def test_password_auth_with_password_accepted(self): + cfg = _valid_cfg(auth='password', password='s3cret') + del cfg['private_key'] + result = self.mgr.configure_sshuttle(cfg) + self.assertTrue(result['ok']) + + def test_invalid_exclude_subnet_rejected(self): + result = self.mgr.configure_sshuttle( + _valid_cfg(exclude_subnets=['not-a-cidr'])) + self.assertFalse(result['ok']) + + def test_result_never_contains_secrets(self): + result = self.mgr.configure_sshuttle(_valid_cfg()) + self.assertNotIn('PRIVATE KEY', str(result)) + + +# --------------------------------------------------------------------------- +# configure_sshuttle — file generation +# --------------------------------------------------------------------------- + +class TestConfigureSshuttleFiles(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.vault = MagicMock() + self.mgr = _make_manager(tmp_dir=self.tmp, vault_manager=self.vault) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_sshuttle_conf_golden(self): + self.mgr.configure_sshuttle(_valid_cfg( + exclude_subnets=['172.20.0.0/16', '10.0.0.0/8'])) + conf = Path(self.mgr.sshuttle_dir, 'sshuttle.conf').read_text() + expected = ( + 'HOST=ssh.example.com\n' + 'PORT=22\n' + 'USER=tunnel\n' + 'AUTH=key\n' + 'LISTEN_PORT=12300\n' + 'EXCLUDE=172.20.0.0/16,10.0.0.0/8\n' + ) + self.assertEqual(conf, expected) + + def test_key_file_written_0600(self): + self.mgr.configure_sshuttle(_valid_cfg()) + key_path = Path(self.mgr.sshuttle_dir, 'id_pic') + self.assertTrue(key_path.is_file()) + mode = stat.S_IMODE(os.stat(key_path).st_mode) + self.assertEqual(mode, 0o600) + self.assertIn('PRIVATE KEY', key_path.read_text()) + + def test_known_hosts_file_written_0600(self): + self.mgr.configure_sshuttle(_valid_cfg()) + kh_path = Path(self.mgr.sshuttle_dir, 'known_hosts') + self.assertTrue(kh_path.is_file()) + mode = stat.S_IMODE(os.stat(kh_path).st_mode) + self.assertEqual(mode, 0o600) + self.assertEqual(kh_path.read_text(), VALID_KNOWN_HOSTS + '\n') + + def test_password_file_written_0600_for_password_auth(self): + cfg = _valid_cfg(auth='password', password='s3cret') + del cfg['private_key'] + self.mgr.configure_sshuttle(cfg) + pw_path = Path(self.mgr.sshuttle_dir, 'password') + self.assertTrue(pw_path.is_file()) + mode = stat.S_IMODE(os.stat(pw_path).st_mode) + self.assertEqual(mode, 0o600) + self.assertEqual(pw_path.read_text(), 's3cret\n') + + def test_default_excludes_contain_cell_subnet_and_rfc1918(self): + self.mgr.configure_sshuttle(_valid_cfg()) + conf = Path(self.mgr.sshuttle_dir, 'sshuttle.conf').read_text() + for net in ('172.20.0.0/16', '10.0.0.0/8', '172.16.0.0/12', + '192.168.0.0/16'): + self.assertIn(net, conf) + + def test_key_stored_in_vault(self): + self.mgr.configure_sshuttle(_valid_cfg()) + self.vault.store_secret.assert_called_once_with( + 'connectivity_sshuttle_key', VALID_KEY) + + def test_password_stored_in_vault(self): + cfg = _valid_cfg(auth='password', password='s3cret') + del cfg['private_key'] + self.mgr.configure_sshuttle(cfg) + self.vault.store_secret.assert_called_once_with( + 'connectivity_sshuttle_password', 's3cret') + + def test_non_secret_fields_persisted_in_config_manager(self): + self.mgr.configure_sshuttle(_valid_cfg()) + self.mgr.config_manager.set_connectivity_field.assert_called_once() + field, exits = self.mgr.config_manager.set_connectivity_field.call_args[0] + self.assertEqual(field, 'exits') + self.assertEqual(exits['sshuttle']['host'], 'ssh.example.com') + self.assertNotIn('private_key', exits['sshuttle']) + self.assertNotIn('password', exits['sshuttle']) + self.assertNotIn('known_hosts', exits['sshuttle']) + + def test_write_failure_returns_ok_false(self): + with patch.object(self.mgr, '_write_secure', side_effect=OSError('disk full')): + result = self.mgr.configure_sshuttle(_valid_cfg()) + self.assertFalse(result['ok']) + + +# --------------------------------------------------------------------------- +# apply_routes — sshuttle REDIRECT +# --------------------------------------------------------------------------- + +class TestApplyRoutesSshuttle(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_sshuttle_peer_gets_redirect_to_12300(self): + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'alice', 'exit_via': 'sshuttle'}, + ] + pr.get_peer.return_value = {'peer': 'alice', 'ip': '172.20.0.50/32'} + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + redirect_calls = [ + c for c in mock_sp.run.call_args_list + if 'REDIRECT' in c.args[0] + ] + self.assertEqual(len(redirect_calls), 1) + args = redirect_calls[0].args[0] + self.assertIn('--to-ports', args) + self.assertEqual(args[args.index('--to-ports') + 1], '12300') + self.assertIn('172.20.0.50', args) + + def test_sshuttle_peer_gets_mark_0x40(self): + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'alice', 'exit_via': 'sshuttle'}, + ] + pr.get_peer.return_value = {'peer': 'alice', 'ip': '172.20.0.50/32'} + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + mark_calls = [ + c for c in mock_sp.run.call_args_list + if 'MARK' in c.args[0] and '172.20.0.50' in c.args[0] + ] + self.assertEqual(len(mark_calls), 1) + args = mark_calls[0].args[0] + self.assertEqual(args[args.index('--set-mark') + 1], hex(0x40)) + + def test_ip_rule_added_for_sshuttle_table_140(self): + mgr = _make_manager(tmp_dir=self.tmp) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + mgr.apply_routes() + rule_adds = [ + c for c in mock_sp.run.call_args_list + if 'rule' in c.args[0] and 'add' in c.args[0] + and hex(0x40) in c.args[0] + ] + self.assertEqual(len(rule_adds), 1) + self.assertIn('140', rule_adds[0].args[0]) + + def test_no_killswitch_for_sshuttle(self): + """sshuttle has no exit iface — _add_killswitch must skip it.""" + mgr = _make_manager(tmp_dir=self.tmp) + self.assertNotIn('sshuttle', mgr.IFACES) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr._add_killswitch(0x40, None) + mock_sp.run.assert_not_called() + + +# --------------------------------------------------------------------------- +# _exit_status — sshuttle bridge +# --------------------------------------------------------------------------- + +class TestSshuttleExitStatus(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _mgr(self, installed=None): + config_manager = MagicMock() + config_manager.get_identity.return_value = {'ip_range': '172.20.0.0/16'} + config_manager.get_installed_services.return_value = installed or {} + return _make_manager(tmp_dir=self.tmp, config_manager=config_manager) + + def test_not_configured_initially(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('sshuttle') + self.assertFalse(info['configured']) + self.assertEqual(info['status'], 'not_configured') + + def test_configured_after_configure_sshuttle(self): + mgr = self._mgr() + mgr.configure_sshuttle(_valid_cfg()) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('sshuttle') + self.assertTrue(info['configured']) + self.assertEqual(info['status'], 'configured') + + def test_configured_when_store_service_installed(self): + mgr = self._mgr(installed={'sshuttle': {'manifest': {'id': 'sshuttle'}}}) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') + info = mgr._exit_status('sshuttle') + self.assertTrue(info['configured']) + + def test_configured_when_container_running(self): + mgr = self._mgr() + + def fake_run(cmd, **kwargs): + if 'inspect' in cmd and 'cell-sshuttle' in cmd: + return MagicMock(returncode=0, stdout='true\n', stderr='') + return MagicMock(returncode=1, stdout='', stderr='') + + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.side_effect = fake_run + info = mgr._exit_status('sshuttle') + self.assertTrue(info['configured']) + + def test_sshuttle_in_list_exits(self): + mgr = self._mgr() + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + exits = mgr.list_exits() + types = {e['type'] for e in exits} + self.assertIn('sshuttle', types) + + +# --------------------------------------------------------------------------- +# set_peer_exit accepts sshuttle +# --------------------------------------------------------------------------- + +class TestSetPeerExitSshuttle(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_sshuttle_is_a_valid_exit_type(self): + pr = MagicMock() + pr.set_peer_exit_via.return_value = True + pr.list_peers.return_value = [] + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + result = mgr.set_peer_exit('alice', 'sshuttle') + self.assertTrue(result['ok']) + + def test_peer_registry_accepts_sshuttle(self): + from peer_registry import PeerRegistry + self.assertIn('sshuttle', PeerRegistry.VALID_EXIT_VIA) + + +# --------------------------------------------------------------------------- +# POST /api/connectivity/exits/sshuttle — route behaviour +# --------------------------------------------------------------------------- + +class TestSshuttleRoute(unittest.TestCase): + + def setUp(self): + import app as app_module + self.app_module = app_module + app_module.app.config['TESTING'] = True + self.client = app_module.app.test_client() + + def test_valid_config_returns_200_ok_only(self): + mock_cm = MagicMock() + mock_cm.configure_sshuttle.return_value = {'ok': True} + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/sshuttle', + json=_valid_cfg()) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.get_json(), {'ok': True}) + + def test_invalid_config_returns_400(self): + mock_cm = MagicMock() + mock_cm.configure_sshuttle.return_value = {'ok': False, 'error': 'invalid host'} + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/sshuttle', + json={'host': '!!'}) + self.assertEqual(resp.status_code, 400) + self.assertFalse(resp.get_json()['ok']) + + def test_response_never_echoes_private_key(self): + mock_cm = MagicMock() + mock_cm.configure_sshuttle.return_value = {'ok': True} + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/sshuttle', + json=_valid_cfg()) + body = resp.get_data(as_text=True) + self.assertNotIn('PRIVATE KEY', body) + self.assertNotIn(VALID_KEY.splitlines()[1], body) + + def test_response_never_echoes_password(self): + mock_cm = MagicMock() + mock_cm.configure_sshuttle.return_value = {'ok': True} + cfg = _valid_cfg(auth='password', password='hunter2-secret') + del cfg['private_key'] + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/sshuttle', json=cfg) + self.assertNotIn('hunter2-secret', resp.get_data(as_text=True)) + + def test_exception_returns_500_without_details(self): + mock_cm = MagicMock() + mock_cm.configure_sshuttle.side_effect = Exception('boom PRIVATE stuff') + with patch.object(self.app_module, 'connectivity_manager', mock_cm): + resp = self.client.post('/api/connectivity/exits/sshuttle', + json=_valid_cfg()) + self.assertEqual(resp.status_code, 500) + self.assertNotIn('PRIVATE', resp.get_data(as_text=True)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_ddns_endpoints.py b/tests/test_ddns_endpoints.py index 50ae230..8802db9 100644 --- a/tests/test_ddns_endpoints.py +++ b/tests/test_ddns_endpoints.py @@ -169,5 +169,29 @@ class TestDdnsRegister(unittest.TestCase): self.assertTrue(body['registered']) self.assertEqual(body['subdomain'], 'mypic.pic.ngo') + +class TestDdnsSyncRecords(unittest.TestCase): + + def setUp(self): + self.client = _make_client() + + def test_sync_success(self): + from app import ddns_manager + with patch.object(ddns_manager, 'sync_service_records', + return_value={'success': True, 'synced': ['a'], 'failed': []}): + r = self.client.post('/api/ddns/sync') + self.assertEqual(r.status_code, 200) + self.assertTrue(json.loads(r.data)['success']) + + def test_sync_ddns_error_returns_400(self): + from app import ddns_manager + from ddns_manager import DDNSError + with patch.object(ddns_manager, 'sync_service_records', + side_effect=DDNSError('no provider')): + r = self.client.post('/api/ddns/sync') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_ddns_manager.py b/tests/test_ddns_manager.py index 08151b3..c0ffece 100644 --- a/tests/test_ddns_manager.py +++ b/tests/test_ddns_manager.py @@ -17,8 +17,6 @@ from ddns_manager import ( PicNgoDDNS, CloudflareDDNS, DuckDNSDDNS, - NoIPDDNS, - FreeDNSDDNS, _get_public_ip, ) @@ -155,7 +153,8 @@ class TestPicNgoDDNSChallenges(unittest.TestCase): mock_post.assert_called_once() args, kwargs = mock_post.call_args self.assertEqual(args[0], 'https://ddns.example.com/api/v1/dns-challenge') - self.assertEqual(kwargs['json'], {'fqdn': '_acme.alpha.pic.ngo', 'value': 'abc123'}) + self.assertEqual(kwargs['json'], + {'fqdn': '_acme.alpha.pic.ngo', 'value': 'abc123', 'token': 'tok'}) self.assertEqual(kwargs['headers']['Authorization'], 'Bearer tok') self.assertTrue(result) @@ -167,7 +166,7 @@ class TestPicNgoDDNSChallenges(unittest.TestCase): mock_del.assert_called_once() args, kwargs = mock_del.call_args self.assertEqual(args[0], 'https://ddns.example.com/api/v1/dns-challenge') - self.assertEqual(kwargs['json'], {'fqdn': '_acme.alpha.pic.ngo'}) + self.assertEqual(kwargs['json'], {'fqdn': '_acme.alpha.pic.ngo', 'token': 'tok'}) self.assertEqual(kwargs['headers']['Authorization'], 'Bearer tok') self.assertTrue(result) @@ -186,6 +185,236 @@ class TestPicNgoDDNSChallenges(unittest.TestCase): provider.dns_challenge_delete('tok', 'fqdn') +# --------------------------------------------------------------------------- +# CloudflareDDNS tests +# --------------------------------------------------------------------------- + +class TestCloudflareDDNSUpdate(unittest.TestCase): + """CloudflareDDNS.update() looks up the A record id, then PATCHes that record.""" + + def _provider(self, domain='cell.example.com'): + return CloudflareDDNS(api_token='cf_tok', zone_id='zid123', domain=domain) + + def test_update_gets_record_id_then_patches_it(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': [{'id': 'rec42'}]}) + patch_resp = _make_response(200) + with patch('requests.get', return_value=get_resp) as mock_get, \ + patch('requests.patch', return_value=patch_resp) as mock_patch: + result = provider.update('unused-token', '5.6.7.8') + + self.assertTrue(result) + # Lookup: GET the dns_records collection filtered by type+name + mock_get.assert_called_once() + get_args, get_kwargs = mock_get.call_args + self.assertEqual( + get_args[0], + 'https://api.cloudflare.com/client/v4/zones/zid123/dns_records', + ) + self.assertEqual(get_kwargs['params'], {'type': 'A', 'name': 'cell.example.com'}) + # Update: PATCH the individual record with the Cloudflare payload shape + mock_patch.assert_called_once() + patch_args, patch_kwargs = mock_patch.call_args + self.assertEqual( + patch_args[0], + 'https://api.cloudflare.com/client/v4/zones/zid123/dns_records/rec42', + ) + self.assertEqual( + patch_kwargs['json'], + {'type': 'A', 'name': 'cell.example.com', 'content': '5.6.7.8'}, + ) + self.assertEqual( + patch_kwargs['headers']['Authorization'], 'Bearer cf_tok' + ) + + def test_update_returns_false_when_record_not_found(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': []}) + with patch('requests.get', return_value=get_resp), \ + patch('requests.patch') as mock_patch: + result = provider.update('tok', '1.2.3.4') + self.assertFalse(result) + mock_patch.assert_not_called() + + def test_update_returns_false_when_lookup_fails(self): + provider = self._provider() + get_resp = _make_response(403, text='forbidden') + with patch('requests.get', return_value=get_resp), \ + patch('requests.patch') as mock_patch: + result = provider.update('tok', '1.2.3.4') + self.assertFalse(result) + mock_patch.assert_not_called() + + def test_update_returns_false_when_patch_fails(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': [{'id': 'rec1'}]}) + patch_resp = _make_response(500, text='server error') + with patch('requests.get', return_value=get_resp), \ + patch('requests.patch', return_value=patch_resp): + result = provider.update('tok', '1.2.3.4') + self.assertFalse(result) + + def test_update_returns_false_without_domain(self): + provider = self._provider(domain='') + with patch('requests.get') as mock_get: + result = provider.update('tok', '1.2.3.4') + self.assertFalse(result) + mock_get.assert_not_called() + + +class TestCloudflareDDNSSyncServiceRecords(unittest.TestCase): + """CloudflareDDNS.sync_service_records() ensures one A record per name.""" + + def _provider(self, domain='cell.example.com'): + return CloudflareDDNS(api_token='cf_tok', zone_id='zid123', domain=domain) + + def test_creates_missing_record_with_post(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': []}) + post_resp = _make_response(200) + with patch('requests.get', return_value=get_resp), \ + patch('requests.post', return_value=post_resp) as mock_post, \ + patch('requests.patch') as mock_patch: + result = provider.sync_service_records(['mail.cell.example.com'], '9.9.9.9') + self.assertTrue(result['success']) + self.assertIn('cell.example.com', result['synced']) + self.assertIn('mail.cell.example.com', result['synced']) + self.assertEqual(result['failed'], []) + mock_patch.assert_not_called() + # apex + one subdomain = 2 POSTs (both missing) + self.assertEqual(mock_post.call_count, 2) + _, kwargs = mock_post.call_args + self.assertEqual(kwargs['json']['type'], 'A') + self.assertEqual(kwargs['json']['content'], '9.9.9.9') + + def test_updates_existing_record_with_patch(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': [{'id': 'rec7'}]}) + patch_resp = _make_response(200) + with patch('requests.get', return_value=get_resp), \ + patch('requests.patch', return_value=patch_resp) as mock_patch, \ + patch('requests.post') as mock_post: + result = provider.sync_service_records(['mail.cell.example.com'], '9.9.9.9') + self.assertTrue(result['success']) + mock_post.assert_not_called() + self.assertEqual(mock_patch.call_count, 2) + patch_args, _ = mock_patch.call_args + self.assertIn('/dns_records/rec7', patch_args[0]) + + def test_reports_failure_when_write_fails(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': []}) + post_resp = _make_response(500, text='server error') + with patch('requests.get', return_value=get_resp), \ + patch('requests.post', return_value=post_resp): + result = provider.sync_service_records(['mail.cell.example.com'], '9.9.9.9') + self.assertFalse(result['success']) + self.assertEqual(set(result['failed']), + {'cell.example.com', 'mail.cell.example.com'}) + + def test_handles_lookup_error_as_failure(self): + provider = self._provider() + get_resp = _make_response(403, text='forbidden') + with patch('requests.get', return_value=get_resp), \ + patch('requests.post') as mock_post, \ + patch('requests.patch') as mock_patch: + result = provider.sync_service_records([], '9.9.9.9') + self.assertFalse(result['success']) + self.assertIn('cell.example.com', result['failed']) + mock_post.assert_not_called() + mock_patch.assert_not_called() + + def test_no_domain_returns_unsuccessful(self): + provider = self._provider(domain='') + with patch('requests.get') as mock_get: + result = provider.sync_service_records(['x.example.com'], '9.9.9.9') + self.assertFalse(result['success']) + mock_get.assert_not_called() + + def test_dedupes_apex_in_subdomains(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': []}) + post_resp = _make_response(200) + with patch('requests.get', return_value=get_resp), \ + patch('requests.post', return_value=post_resp) as mock_post: + result = provider.sync_service_records(['cell.example.com'], '9.9.9.9') + # apex passed again as a subdomain must not double-write + self.assertEqual(result['synced'], ['cell.example.com']) + self.assertEqual(mock_post.call_count, 1) + + +class TestCloudflareDDNSChallenges(unittest.TestCase): + """CloudflareDDNS DNS-01 challenge record creation and deletion.""" + + def _provider(self): + return CloudflareDDNS(api_token='cf_tok', zone_id='zid123', domain='cell.example.com') + + def test_dns_challenge_create_posts_txt_record(self): + provider = self._provider() + post_resp = _make_response(200) + with patch('requests.post', return_value=post_resp) as mock_post: + result = provider.dns_challenge_create('tok', '_acme.cell.example.com', 'val') + self.assertTrue(result) + _, kwargs = mock_post.call_args + self.assertEqual(kwargs['json']['type'], 'TXT') + self.assertEqual(kwargs['json']['name'], '_acme.cell.example.com') + self.assertEqual(kwargs['json']['content'], 'val') + + def test_dns_challenge_delete_deletes_record_by_id(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': [{'id': 'txt9'}]}) + del_resp = _make_response(200) + with patch('requests.get', return_value=get_resp) as mock_get, \ + patch('requests.delete', return_value=del_resp) as mock_del: + result = provider.dns_challenge_delete('tok', '_acme.cell.example.com') + self.assertTrue(result) + _, get_kwargs = mock_get.call_args + self.assertEqual(get_kwargs['params'], {'type': 'TXT', 'name': '_acme.cell.example.com'}) + del_args, _ = mock_del.call_args + self.assertEqual( + del_args[0], + 'https://api.cloudflare.com/client/v4/zones/zid123/dns_records/txt9', + ) + + def test_dns_challenge_delete_deletes_all_matching_records(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': [{'id': 'a'}, {'id': 'b'}]}) + del_resp = _make_response(200) + with patch('requests.get', return_value=get_resp), \ + patch('requests.delete', return_value=del_resp) as mock_del: + result = provider.dns_challenge_delete('tok', '_acme.cell.example.com') + self.assertTrue(result) + self.assertEqual(mock_del.call_count, 2) + + def test_dns_challenge_delete_returns_false_when_no_record(self): + """Must NOT pretend success when there is nothing to delete.""" + provider = self._provider() + get_resp = _make_response(200, json_data={'result': []}) + with patch('requests.get', return_value=get_resp), \ + patch('requests.delete') as mock_del: + result = provider.dns_challenge_delete('tok', '_acme.cell.example.com') + self.assertFalse(result) + mock_del.assert_not_called() + + def test_dns_challenge_delete_returns_false_when_delete_fails(self): + provider = self._provider() + get_resp = _make_response(200, json_data={'result': [{'id': 'txt9'}]}) + del_resp = _make_response(500, text='error') + with patch('requests.get', return_value=get_resp), \ + patch('requests.delete', return_value=del_resp): + result = provider.dns_challenge_delete('tok', '_acme.cell.example.com') + self.assertFalse(result) + + def test_dns_challenge_delete_returns_false_when_lookup_fails(self): + provider = self._provider() + get_resp = _make_response(401, text='unauthorized') + with patch('requests.get', return_value=get_resp), \ + patch('requests.delete') as mock_del: + result = provider.dns_challenge_delete('tok', '_acme.cell.example.com') + self.assertFalse(result) + mock_del.assert_not_called() + + # --------------------------------------------------------------------------- # DDNSManager.get_provider() tests # --------------------------------------------------------------------------- @@ -231,17 +460,51 @@ class TestGetProvider(unittest.TestCase): provider = mgr.get_provider() self.assertIsInstance(provider, DuckDNSDDNS) - def test_returns_noip_provider(self): + def test_noip_provider_rejected(self): + """'noip' is not yet supported — get_provider() must fail loudly.""" cm = _make_config_manager(ddns_cfg={'provider': 'noip'}) mgr = DDNSManager(config_manager=cm) - provider = mgr.get_provider() - self.assertIsInstance(provider, NoIPDDNS) + with self.assertRaises(DDNSError) as ctx: + mgr.get_provider() + self.assertIn('not yet supported', str(ctx.exception)) - def test_returns_freedns_provider(self): + def test_freedns_provider_rejected(self): + """'freedns' is not yet supported — get_provider() must fail loudly.""" cm = _make_config_manager(ddns_cfg={'provider': 'freedns'}) mgr = DDNSManager(config_manager=cm) + with self.assertRaises(DDNSError) as ctx: + mgr.get_provider() + self.assertIn('not yet supported', str(ctx.exception)) + + def test_test_connectivity_reports_unsupported_provider(self): + """test_connectivity() must not raise for unsupported providers.""" + cm = _make_config_manager(ddns_cfg={'provider': 'noip'}) + mgr = DDNSManager(config_manager=cm) + result = mgr.test_connectivity() + self.assertFalse(result['success']) + self.assertIn('not yet supported', result['reason']) + + def test_cloudflare_provider_gets_domain_from_config(self): + cm = _make_config_manager(ddns_cfg={ + 'provider': 'cloudflare', + 'api_token': 'cf_tok', + 'zone_id': 'zid', + 'domain': 'cell.example.com', + }) + mgr = DDNSManager(config_manager=cm) provider = mgr.get_provider() - self.assertIsInstance(provider, FreeDNSDDNS) + self.assertEqual(provider.domain, 'cell.example.com') + + def test_cloudflare_provider_falls_back_to_identity_domain(self): + cm = _make_config_manager(ddns_cfg={ + 'provider': 'cloudflare', + 'api_token': 'cf_tok', + 'zone_id': 'zid', + }) + cm.get_identity.return_value = {'domain_name': 'ident.example.com'} + mgr = DDNSManager(config_manager=cm) + provider = mgr.get_provider() + self.assertEqual(provider.domain, 'ident.example.com') def test_returns_none_for_unknown_provider(self): cm = _make_config_manager(ddns_cfg={'provider': 'nonexistent'}) @@ -260,6 +523,60 @@ class TestGetProvider(unittest.TestCase): self.assertEqual(provider.api_base_url, 'https://custom.example.com') +# --------------------------------------------------------------------------- +# DDNSManager.sync_service_records() tests +# --------------------------------------------------------------------------- + +class TestManagerSyncServiceRecords(unittest.TestCase): + """DDNSManager.sync_service_records builds names and delegates to the provider.""" + + def _manager(self, routes): + cm = _make_config_manager(ddns_cfg={'provider': 'cloudflare'}) + cm.get_effective_domain.return_value = 'cell.example.com' + registry = MagicMock() + registry.get_caddy_routes.return_value = routes + mgr = DDNSManager(config_manager=cm, service_registry=registry) + return mgr + + def test_delegates_to_provider_with_fqdns(self): + mgr = self._manager([ + {'subdomain': 'mail', 'extra_subdomains': []}, + {'subdomain': 'cal', 'extra_subdomains': ['dav']}, + ]) + provider = MagicMock() + provider.sync_service_records.return_value = {'success': True, 'synced': [], 'failed': []} + mgr.get_provider = MagicMock(return_value=provider) + with patch('ddns_manager._get_public_ip', return_value='7.7.7.7'): + result = mgr.sync_service_records() + self.assertTrue(result['success']) + names, ip = provider.sync_service_records.call_args[0] + self.assertEqual(ip, '7.7.7.7') + self.assertIn('mail.cell.example.com', names) + self.assertIn('cal.cell.example.com', names) + self.assertIn('dav.cell.example.com', names) + + def test_raises_when_no_provider(self): + mgr = self._manager([]) + mgr.get_provider = MagicMock(return_value=None) + with self.assertRaises(DDNSError): + mgr.sync_service_records() + + def test_raises_when_provider_lacks_support(self): + mgr = self._manager([]) + provider = MagicMock(spec=['update', 'register']) + mgr.get_provider = MagicMock(return_value=provider) + with self.assertRaises(DDNSError): + mgr.sync_service_records() + + def test_raises_when_no_public_ip(self): + mgr = self._manager([]) + provider = MagicMock() + mgr.get_provider = MagicMock(return_value=provider) + with patch('ddns_manager._get_public_ip', return_value=None): + with self.assertRaises(DDNSError): + mgr.sync_service_records() + + # --------------------------------------------------------------------------- # DDNSManager.update_ip() tests # --------------------------------------------------------------------------- diff --git a/tests/test_enhanced_api.py b/tests/test_enhanced_api.py new file mode 100644 index 0000000..5f96f1b --- /dev/null +++ b/tests/test_enhanced_api.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python3 +""" +Comprehensive Test Suite for Enhanced Personal Internet Cell API +Tests all new components and integrations +""" + +import unittest +import json +import tempfile +import os +import shutil +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock +import sys +import threading +import time + +# Add the api directory to the path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from base_service_manager import BaseServiceManager +from config_manager import ConfigManager +from service_bus import ServiceBus, EventType, Event +from log_manager import LogManager, LogLevel +from network_manager import NetworkManager +from enhanced_cli import APIClient, ConfigManager as CLIConfigManager, EnhancedCLI + +class TestBaseServiceManager(unittest.TestCase): + """Test the base service manager functionality""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.data_dir = os.path.join(self.temp_dir, 'data') + self.config_dir = os.path.join(self.temp_dir, 'config') + os.makedirs(self.data_dir, exist_ok=True) + os.makedirs(self.config_dir, exist_ok=True) + + # Create a concrete implementation for testing + class TestServiceManager(BaseServiceManager): + def get_status(self): + return {'running': True, 'status': 'online'} + + def test_connectivity(self): + return {'success': True, 'message': 'Connected'} + + self.service_manager = TestServiceManager('test_service', self.data_dir, self.config_dir) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_initialization(self): + """Test service manager initialization""" + self.assertEqual(self.service_manager.service_name, 'test_service') + self.assertEqual(self.service_manager.data_dir, self.data_dir) + self.assertEqual(self.service_manager.config_dir, self.config_dir) + self.assertTrue(os.path.exists(self.data_dir)) + self.assertTrue(os.path.exists(self.config_dir)) + + def test_get_status(self): + """Test get_status method""" + status = self.service_manager.get_status() + self.assertEqual(status['running'], True) + self.assertEqual(status['status'], 'online') + + def test_test_connectivity(self): + """Test test_connectivity method""" + connectivity = self.service_manager.test_connectivity() + self.assertEqual(connectivity['success'], True) + self.assertEqual(connectivity['message'], 'Connected') + + def test_get_logs(self): + """Test get_logs method""" + # Create a test log file + log_file = os.path.join(self.data_dir, 'test_service.log') + with open(log_file, 'w') as f: + f.write("Test log line 1\n") + f.write("Test log line 2\n") + + logs = self.service_manager.get_logs(lines=2) + self.assertEqual(len(logs), 2) + self.assertIn("Test log line 1", logs[0]) + self.assertIn("Test log line 2", logs[1]) + + def test_get_config(self): + """Test get_config method""" + # Create a test config file + config_file = os.path.join(self.config_dir, 'test_service.json') + test_config = {'key': 'value', 'number': 42} + with open(config_file, 'w') as f: + json.dump(test_config, f) + + config = self.service_manager.get_config() + self.assertEqual(config['key'], 'value') + self.assertEqual(config['number'], 42) + + def test_update_config(self): + """Test update_config method""" + test_config = {'new_key': 'new_value', 'number': 100} + success = self.service_manager.update_config(test_config) + self.assertTrue(success) + + # Verify config was saved + config = self.service_manager.get_config() + self.assertEqual(config['new_key'], 'new_value') + self.assertEqual(config['number'], 100) + + def test_validate_config(self): + """Test validate_config method""" + test_config = {'key': 'value'} + validation = self.service_manager.validate_config(test_config) + self.assertTrue(validation['valid']) + self.assertEqual(len(validation['errors']), 0) + + def test_get_metrics(self): + """Test get_metrics method""" + metrics = self.service_manager.get_metrics() + self.assertEqual(metrics['service'], 'test_service') + self.assertIn('timestamp', metrics) + self.assertEqual(metrics['status'], 'unknown') + + def test_handle_error(self): + """Test handle_error method""" + test_error = ValueError("Test error") + error_info = self.service_manager.handle_error(test_error, "test_context") + + self.assertEqual(error_info['error'], "Test error") + self.assertEqual(error_info['type'], "ValueError") + self.assertEqual(error_info['context'], "test_context") + self.assertEqual(error_info['service'], 'test_service') + self.assertIn('traceback', error_info) + + def test_health_check(self): + """Test health_check method""" + health = self.service_manager.health_check() + + self.assertEqual(health['service'], 'test_service') + self.assertIn('timestamp', health) + self.assertIn('status', health) + self.assertIn('connectivity', health) + self.assertIn('metrics', health) + self.assertIn('healthy', health) + self.assertTrue(health['healthy']) + +class TestConfigManager(unittest.TestCase): + """Test the configuration manager functionality""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.config_dir = os.path.join(self.temp_dir, 'config') + self.data_dir = os.path.join(self.temp_dir, 'data') + os.makedirs(self.config_dir, exist_ok=True) + os.makedirs(self.data_dir, exist_ok=True) + + self.config_file = os.path.join(self.config_dir, 'cell_config.json') + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] TestConfigManager.setUp: self.config_file = {self.config_file}") + # Ensure the config file exists and is a valid JSON file + if not os.path.exists(self.config_file): + with open(self.config_file, 'w') as f: + json.dump({}, f) + self.config_manager = ConfigManager(self.config_file, self.data_dir) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + if os.path.exists(self.config_file): + os.remove(self.config_file) + + def test_initialization(self): + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] test_initialization: self.config_file = {self.config_file}") + """Test config manager initialization""" + self.assertTrue(os.path.exists(self.config_dir)) + self.assertTrue(os.path.exists(self.data_dir)) + self.assertTrue(os.path.exists(self.config_manager.backup_dir)) + self.assertIsNotNone(self.config_manager.service_schemas) + + def test_get_service_config(self): + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] test_get_service_config: self.config_file = {self.config_file}") + """Test getting service configuration""" + # Test with non-existent service + with self.assertRaises(ValueError): + self.config_manager.get_service_config('nonexistent_service') + + # Test with valid service + config = self.config_manager.get_service_config('network') + self.assertEqual(config, {}) + + def test_update_service_config(self): + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] test_update_service_config: self.config_file = {self.config_file}") + """Test updating service configuration""" + test_config = { + 'dns_port': 53, + 'dhcp_range': '10.0.0.100-10.0.0.200', + 'ntp_servers': ['pool.ntp.org'] + } + + success = self.config_manager.update_service_config('network', test_config) + self.assertTrue(success) + + # Verify config was saved + config = self.config_manager.get_service_config('network') + self.assertEqual(config['dns_port'], 53) + self.assertEqual(config['dhcp_range'], '10.0.0.100-10.0.0.200') + self.assertEqual(config['ntp_servers'], ['pool.ntp.org']) + + def test_validate_config(self): + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] test_validate_config: self.config_file = {self.config_file}") + """Test configuration validation""" + # Test valid config + valid_config = { + 'dns_port': 53, + 'dhcp_range': '10.0.0.100-10.0.0.200', + 'ntp_servers': ['pool.ntp.org'] + } + validation = self.config_manager.validate_config('network', valid_config) + self.assertTrue(validation['valid']) + self.assertEqual(len(validation['errors']), 0) + + # Test invalid config (missing required field) + invalid_config = { + 'dns_port': 53 + # Missing ntp_servers + } + validation = self.config_manager.validate_config('network', invalid_config) + self.assertFalse(validation['valid']) + self.assertGreater(len(validation['errors']), 0) + + # Test invalid config (wrong type) + invalid_type_config = { + 'dns_port': 'not_a_number', + 'dhcp_range': '10.0.0.100-10.0.0.200', + 'ntp_servers': ['pool.ntp.org'] + } + validation = self.config_manager.validate_config('network', invalid_type_config) + self.assertFalse(validation['valid']) + self.assertGreater(len(validation['errors']), 0) + + def test_backup_and_restore(self): + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] test_backup_and_restore: self.config_file = {self.config_file}") + """Test configuration backup and restore""" + # Create some test configurations + test_configs = { + 'network': {'dns_port': 53, 'dhcp_range': '10.0.0.100-10.0.0.200', 'ntp_servers': ['pool.ntp.org']}, + 'wireguard': {'port': 51820, 'private_key': 'test_key', 'address': '10.0.0.1/24'} + } + + for service, config in test_configs.items(): + self.config_manager.update_service_config(service, config) + + # Create backup + backup_id = self.config_manager.backup_config() + self.assertIsNotNone(backup_id) + + # List backups + backups = self.config_manager.list_backups() + self.assertEqual(len(backups), 1) + self.assertEqual(backups[0]['backup_id'], backup_id) + + # Modify config + self.config_manager.update_service_config('network', {'dns_port': 5353}) + + # Restore backup + success = self.config_manager.restore_config(backup_id) + self.assertTrue(success) + + # Verify restoration + config = self.config_manager.get_service_config('network') + self.assertEqual(config['dns_port'], 53) # Should be restored value + + def test_export_import_config(self): + assert not os.path.isdir(self.config_file), f"self.config_file is a directory: {self.config_file}" + print(f"[DEBUG] test_export_import_config: self.config_file = {self.config_file}") + """Test configuration export and import""" + # Create test configurations + test_configs = { + 'network': {'dns_port': 53, 'dhcp_range': '10.0.0.100-10.0.0.200', 'ntp_servers': ['pool.ntp.org']}, + 'wireguard': {'port': 51820, 'private_key': 'test_key', 'address': '10.0.0.1/24'} + } + + for service, config in test_configs.items(): + self.config_manager.update_service_config(service, config) + + # Export configuration + exported_json = self.config_manager.export_config('json') + exported_yaml = self.config_manager.export_config('yaml') + + self.assertIsInstance(exported_json, str) + self.assertIsInstance(exported_yaml, str) + + # Clear unified config file + if os.path.exists(self.config_file): + os.remove(self.config_file) + + # Import configuration + success = self.config_manager.import_config(exported_json, 'json') + self.assertTrue(success) + + # Verify import + for service, expected_config in test_configs.items(): + config = self.config_manager.get_service_config(service) + for key, value in expected_config.items(): + self.assertEqual(config[key], value) + + # Also verify that required fields are present (even if with default values) + schema = self.config_manager.service_schemas[service] + for field in schema['required']: + self.assertIn(field, config) + +class TestServiceBus(unittest.TestCase): + """Test the service bus functionality""" + + def setUp(self): + self.service_bus = ServiceBus() + + def test_initialization(self): + """Test service bus initialization""" + self.assertFalse(self.service_bus.running) + self.assertEqual(len(self.service_bus.service_registry), 0) + self.assertEqual(len(self.service_bus.event_handlers), 0) + + def test_start_stop(self): + """Test service bus start and stop""" + self.service_bus.start() + self.assertTrue(self.service_bus.running) + self.assertIsNotNone(self.service_bus.event_loop_thread) + + self.service_bus.stop() + self.assertFalse(self.service_bus.running) + + def test_register_unregister_service(self): + """Test service registration and unregistration""" + mock_service = Mock() + mock_service.get_status.return_value = {'running': True} + + # Register service + self.service_bus.register_service('test_service', mock_service) + self.assertIn('test_service', self.service_bus.service_registry) + self.assertEqual(self.service_bus.service_registry['test_service'], mock_service) + + # Unregister service + self.service_bus.unregister_service('test_service') + self.assertNotIn('test_service', self.service_bus.service_registry) + + def test_publish_subscribe_events(self): + """Test event publishing and subscription""" + events_received = [] + + def event_handler(event): + events_received.append(event) + + # Subscribe to events + self.service_bus.subscribe_to_event(EventType.SERVICE_STARTED, event_handler) + + # Start service bus + self.service_bus.start() + + # Publish event + test_data = {'service': 'test_service', 'timestamp': datetime.utcnow().isoformat()} + self.service_bus.publish_event(EventType.SERVICE_STARTED, 'test_service', test_data) + + # Wait for event processing + time.sleep(0.1) + + # Check if event was received + self.assertEqual(len(events_received), 1) + self.assertEqual(events_received[0].event_type, EventType.SERVICE_STARTED) + self.assertEqual(events_received[0].source, 'test_service') + self.assertEqual(events_received[0].data, test_data) + + self.service_bus.stop() + + def test_call_service(self): + """Test service method calling""" + # Create a real service class instead of Mock + class TestService: + def test_method(self, arg1=None): + return 'test_result' + + test_service = TestService() + self.service_bus.register_service('test_service', test_service) + + # Call service method + result = self.service_bus.call_service('test_service', 'test_method', arg1='value1') + self.assertEqual(result, 'test_result') + + # Test calling non-existent service + with self.assertRaises(ValueError): + self.service_bus.call_service('nonexistent_service', 'test_method') + + # Test calling non-existent method + with self.assertRaises(ValueError): + self.service_bus.call_service('test_service', 'nonexistent_method') + + def test_service_orchestration(self): + """Test service orchestration""" + mock_service = Mock() + mock_service.start = Mock() + mock_service.stop = Mock() + + self.service_bus.register_service('test_service', mock_service) + + # Test service start orchestration + success = self.service_bus.orchestrate_service_start('test_service') + self.assertTrue(success) + mock_service.start.assert_called_once() + + # Test service stop orchestration + success = self.service_bus.orchestrate_service_stop('test_service') + self.assertTrue(success) + mock_service.stop.assert_called_once() + + # Test service restart orchestration + success = self.service_bus.orchestrate_service_restart('test_service') + self.assertTrue(success) + self.assertEqual(mock_service.start.call_count, 2) + self.assertEqual(mock_service.stop.call_count, 2) + + def test_event_history(self): + """Test event history functionality""" + self.service_bus.start() + + # Publish some events + for i in range(5): + self.service_bus.publish_event(EventType.SERVICE_STARTED, f'service_{i}', {'index': i}) + + # Wait for event processing + time.sleep(0.1) + + # Get event history + events = self.service_bus.get_event_history(limit=3) + self.assertEqual(len(events), 3) + + # Test filtering by event type + started_events = self.service_bus.get_event_history(EventType.SERVICE_STARTED, limit=2) + self.assertEqual(len(started_events), 2) + for event in started_events: + self.assertEqual(event.event_type, EventType.SERVICE_STARTED) + + # Test filtering by source + service_0_events = self.service_bus.get_event_history(source='service_0') + self.assertEqual(len(service_0_events), 1) + self.assertEqual(service_0_events[0].source, 'service_0') + + self.service_bus.stop() + +class TestLogManager(unittest.TestCase): + """Test the log manager functionality""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.log_dir = os.path.join(self.temp_dir, 'logs') + os.makedirs(self.log_dir, exist_ok=True) + + self.log_manager = LogManager(self.log_dir) + + def tearDown(self): + self.log_manager.stop() + shutil.rmtree(self.temp_dir) + + def test_initialization(self): + """Test log manager initialization""" + self.assertTrue(os.path.exists(self.log_dir)) + self.assertIsNotNone(self.log_manager.formatters) + self.assertIsNotNone(self.log_manager.handlers) + self.assertTrue(self.log_manager.running) + + def test_add_service_logger(self): + """Test adding service loggers""" + config = {'level': 'INFO', 'formatter': 'json', 'console': False} + self.log_manager.add_service_logger('test_service', config) + + self.assertIn('test_service', self.log_manager.service_loggers) + self.assertIn('test_service', self.log_manager.handlers) + + def test_get_service_logs(self): + """Test getting service logs""" + # Add service logger first + config = {'level': 'INFO', 'formatter': 'json', 'console': False} + self.log_manager.add_service_logger('test_service', config) + + # Create a test log file in the correct location + log_file = self.log_manager.log_dir / 'test_service.log' + with open(log_file, 'w') as f: + f.write('{"timestamp": "2024-01-01T10:00:00Z", "level": "INFO", "message": "Test log 1"}\n') + f.write('{"timestamp": "2024-01-01T10:01:00Z", "level": "ERROR", "message": "Test log 2"}\n') + f.write('{"timestamp": "2024-01-01T10:02:00Z", "level": "INFO", "message": "Test log 3"}\n') + + # Test getting all logs + logs = self.log_manager.get_service_logs_parsed('test_service', level='ALL', lines=3) + self.assertEqual(len(logs), 3) + + # Test filtering by level + error_logs = self.log_manager.get_service_logs_parsed('test_service', level='ERROR', lines=10) + self.assertEqual(len(error_logs), 1) + self.assertEqual(error_logs[0]['level'], 'ERROR') + + def test_search_logs(self): + """Test log search functionality""" + # Add service loggers first + config = {'level': 'INFO', 'formatter': 'json', 'console': False} + services = ['service1', 'service2'] + for service in services: + self.log_manager.add_service_logger(service, config) + + # Create test log files in the correct location + for service in services: + log_file = self.log_manager.log_dir / f'{service}.log' + with open(log_file, 'w') as f: + f.write('{"timestamp": "2024-01-01T10:00:00Z", "level": "INFO", "message": "Test message for ' + service + '"}\n') + f.write('{"timestamp": "2024-01-01T10:01:00Z", "level": "ERROR", "message": "Error in ' + service + '"}\n') + + # Test search across all services + results = self.log_manager.search_logs('Test message') + self.assertEqual(len(results), 2) + + # Test search with service filter + results = self.log_manager.search_logs('Error', services=['service1']) + self.assertEqual(len(results), 1) + self.assertIn('service1', results[0]['service']) + + # Test search with level filter + results = self.log_manager.search_logs('', level='ERROR') + self.assertEqual(len(results), 2) + for result in results: + self.assertEqual(result['level'], 'ERROR') + + def test_export_logs(self): + """Test log export functionality""" + # Add service logger first + config = {'level': 'INFO', 'formatter': 'json', 'console': False} + self.log_manager.add_service_logger('test_service', config) + + # Create test log file in the correct location + log_file = self.log_manager.log_dir / 'test_service.log' + with open(log_file, 'w') as f: + f.write('{"timestamp": "2024-01-01T10:00:00Z", "level": "INFO", "message": "Test log"}\n') + + # Test JSON export + json_export = self.log_manager.export_logs('json') + self.assertIsInstance(json_export, str) + self.assertIn('Test log', json_export) + + # Test CSV export + csv_export = self.log_manager.export_logs('csv') + self.assertIsInstance(csv_export, str) + self.assertIn('Test log', csv_export) + + # Test text export + text_export = self.log_manager.export_logs('text') + self.assertIsInstance(text_export, str) + self.assertIn('Test log', text_export) + + def test_log_statistics(self): + """Test log statistics functionality""" + # Create test log file + log_file = os.path.join(self.log_dir, 'test_service.log') + with open(log_file, 'w') as f: + f.write('{"timestamp": "2024-01-01T10:00:00Z", "level": "INFO", "message": "Info log"}\n') + f.write('{"timestamp": "2024-01-01T10:01:00Z", "level": "ERROR", "message": "Error log"}\n') + f.write('{"timestamp": "2024-01-01T10:02:00Z", "level": "WARNING", "message": "Warning log"}\n') + + # Get statistics + stats = self.log_manager.get_log_statistics('test_service') + self.assertIn('test_service', stats) + self.assertEqual(stats['test_service']['total_entries'], 3) + self.assertIn('level_counts', stats['test_service']) + self.assertEqual(stats['test_service']['level_counts']['INFO'], 1) + self.assertEqual(stats['test_service']['level_counts']['ERROR'], 1) + self.assertEqual(stats['test_service']['level_counts']['WARNING'], 1) + +class TestEnhancedCLI(unittest.TestCase): + """Test the enhanced CLI functionality""" + + def setUp(self): + self.cli = EnhancedCLI() + + def test_api_client(self): + """Test API client functionality""" + client = APIClient() + self.assertEqual(client.base_url, "http://localhost:3000/api") + self.assertIsNotNone(client.session) + + def test_cli_config_manager(self): + """Test CLI configuration manager""" + config_manager = CLIConfigManager() + self.assertIsNotNone(config_manager.config) + + # Test get/set + config_manager.set('test_key', 'test_value') + self.assertEqual(config_manager.get('test_key'), 'test_value') + + # Test export/import + exported = config_manager.export_config('json') + self.assertIsInstance(exported, str) + self.assertIn('test_key', exported) + + def test_cli_commands(self): + """Test CLI commands""" + # Test status command + with patch.object(self.cli.api_client, 'request') as mock_request: + mock_request.return_value = { + 'cell_name': 'test-cell', + 'domain': 'test.local', + 'peers_count': 2, + 'services': {'network': {'running': True}} + } + + # Capture print output + from io import StringIO + import sys + old_stdout = sys.stdout + sys.stdout = StringIO() + + try: + self.cli.do_status("") + output = sys.stdout.getvalue() + self.assertIn('test-cell', output) + self.assertIn('test.local', output) + finally: + sys.stdout = old_stdout + +class TestNetworkManagerIntegration(unittest.TestCase): + """Test NetworkManager integration with BaseServiceManager""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.data_dir = os.path.join(self.temp_dir, 'data') + self.config_dir = os.path.join(self.temp_dir, 'config') + os.makedirs(self.data_dir, exist_ok=True) + os.makedirs(self.config_dir, exist_ok=True) + + self.network_manager = NetworkManager(self.data_dir, self.config_dir) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_inheritance(self): + """Test that NetworkManager inherits from BaseServiceManager""" + self.assertIsInstance(self.network_manager, BaseServiceManager) + self.assertEqual(self.network_manager.service_name, 'network') + + def test_get_status(self): + """Test NetworkManager get_status method""" + status = self.network_manager.get_status() + self.assertIn('timestamp', status) + self.assertIn('network', status) + + def test_test_connectivity(self): + """Test NetworkManager test_connectivity method""" + connectivity = self.network_manager.test_connectivity() + self.assertIn('timestamp', connectivity) + self.assertIn('network', connectivity) + +def run_tests(): + """Run all tests""" + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestBaseServiceManager, + TestConfigManager, + TestServiceBus, + TestLogManager, + TestEnhancedCLI, + TestNetworkManagerIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print(f"\n{'='*50}") + print(f"Test Summary:") + print(f"Tests run: {result.testsRun}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + print(f"Success rate: {((result.testsRun - len(result.failures) - len(result.errors)) / result.testsRun * 100):.1f}%") + print(f"{'='*50}") + + return result.wasSuccessful() + +if __name__ == '__main__': + success = run_tests() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/tests/test_enhanced_cli_extra.py b/tests/test_enhanced_cli_extra.py new file mode 100644 index 0000000..d0a88a0 --- /dev/null +++ b/tests/test_enhanced_cli_extra.py @@ -0,0 +1,696 @@ +#!/usr/bin/env python3 +""" +Additional tests for enhanced_cli.py covering uncovered paths: +- EnhancedCLI.do_* methods +- EnhancedCLI._display_* methods +- EnhancedCLI.show_status, list_services, show_config +- EnhancedCLI.batch_start_services, batch_stop_services +- APIClient.request (PUT, DELETE branches, error handling) +- Module-level: batch_operations, export_config, import_config +""" + +import sys +import json +import tempfile +import os +import shutil +import unittest +from io import StringIO +from pathlib import Path +from unittest.mock import patch, MagicMock, call + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from enhanced_cli import EnhancedCLI, APIClient, ConfigManager, batch_operations, export_config, import_config + + +class TestAPIClientExtra(unittest.TestCase): + def setUp(self): + self.client = APIClient('http://localhost:3000/api') + + def test_request_put_calls_session_put(self): + mock_resp = MagicMock() + mock_resp.json.return_value = {'ok': True} + mock_resp.raise_for_status.return_value = None + with patch.object(self.client.session, 'put', return_value=mock_resp) as mock_put: + result = self.client.request('PUT', '/config', {'key': 'val'}) + mock_put.assert_called_once() + self.assertEqual(result, {'ok': True}) + + def test_request_delete_calls_session_delete(self): + mock_resp = MagicMock() + mock_resp.json.return_value = {'deleted': True} + mock_resp.raise_for_status.return_value = None + with patch.object(self.client.session, 'delete', return_value=mock_resp) as mock_del: + result = self.client.request('DELETE', '/peers/alice') + mock_del.assert_called_once() + self.assertEqual(result, {'deleted': True}) + + def test_request_exception_returns_none(self): + import requests as _req + with patch.object(self.client.session, 'get', + side_effect=_req.exceptions.RequestException('timeout')): + result = self.client.request('GET', '/status') + self.assertIsNone(result) + + +class TestEnhancedCLIDoMethods(unittest.TestCase): + def setUp(self): + self.cli = EnhancedCLI.__new__(EnhancedCLI) + self.cli.api_client = MagicMock() + self.cli.config_manager = MagicMock() + self.cli.current_service = None + self.cli.prompt = 'picell> ' + + # ── do_status ───────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_status_prints_status(self, mock_print): + self.cli.api_client.request.return_value = {'cell_name': 'mycel', 'peers_count': 2} + self.cli.do_status('') + self.assertTrue(any('mycel' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_do_status_prints_error_when_api_fails(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_status('') + mock_print.assert_any_call('❌ Failed to get status') + + # ── do_services ─────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_services_prints_services(self, mock_print): + self.cli.api_client.request.return_value = { + 'email': {'running': True, 'status': 'online'}} + self.cli.do_services('') + self.assertTrue(any('email' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_do_services_prints_error_on_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_services('') + mock_print.assert_any_call('❌ Failed to get services status') + + # ── do_peers ────────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_peers_empty_list_prints_message(self, mock_print): + self.cli.api_client.request.return_value = [] + self.cli.do_peers('') + mock_print.assert_any_call('📭 No peers configured.') + + @patch('builtins.print') + def test_do_peers_error_when_none_returned(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_peers('') + mock_print.assert_any_call('❌ Failed to fetch peers') + + @patch('builtins.print') + def test_do_peers_shows_peer_list(self, mock_print): + self.cli.api_client.request.return_value = [ + {'name': 'alice', 'ip': '10.0.0.2', 'public_key': 'abc123xyz', 'added_at': '2026-01-01'} + ] + self.cli.do_peers('') + self.assertTrue(any('alice' in str(c) for c in mock_print.call_args_list)) + + # ── do_add_peer ─────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_add_peer_too_few_args(self, mock_print): + self.cli.do_add_peer('alice') + mock_print.assert_any_call('❌ Usage: add_peer ') + + @patch('builtins.print') + def test_do_add_peer_success(self, mock_print): + self.cli.api_client.request.return_value = {'message': 'Added'} + self.cli.do_add_peer('alice 10.0.0.2 abc123key') + mock_print.assert_any_call('✅ Added') + + @patch('builtins.print') + def test_do_add_peer_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_add_peer('alice 10.0.0.2 abc123key') + mock_print.assert_any_call('❌ Failed to add peer') + + # ── do_remove_peer ──────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_remove_peer_no_arg(self, mock_print): + self.cli.do_remove_peer('') + mock_print.assert_any_call('❌ Usage: remove_peer ') + + @patch('builtins.print') + def test_do_remove_peer_success(self, mock_print): + self.cli.api_client.request.return_value = {'message': 'Removed'} + self.cli.do_remove_peer('alice') + mock_print.assert_any_call('✅ Removed') + + @patch('builtins.print') + def test_do_remove_peer_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_remove_peer('alice') + mock_print.assert_any_call('❌ Failed to remove peer') + + # ── do_config ───────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_config_shows_config(self, mock_print): + self.cli.api_client.request.return_value = {'cell_name': 'mycel'} + self.cli.do_config('') + self.assertTrue(any('cell_name' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_do_config_error_on_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_config('') + mock_print.assert_any_call('❌ Failed to get configuration') + + # ── do_update_config ────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_update_config_too_few_args(self, mock_print): + self.cli.do_update_config('cell_name') + mock_print.assert_any_call('❌ Usage: update_config ') + + @patch('builtins.print') + def test_do_update_config_success(self, mock_print): + self.cli.api_client.request.return_value = {'message': 'Updated'} + self.cli.do_update_config('cell_name newcell') + mock_print.assert_any_call('✅ Updated') + + @patch('builtins.print') + def test_do_update_config_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_update_config('cell_name newcell') + mock_print.assert_any_call('❌ Failed to update configuration') + + # ── do_logs ─────────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_logs_with_log_data(self, mock_print): + self.cli.api_client.request.return_value = {'log': 'line1\nline2\n'} + self.cli.do_logs('api 10') + self.assertTrue(any('line1' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_do_logs_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_logs('') + mock_print.assert_any_call('❌ Failed to get logs') + + # ── do_health ───────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_health_shows_history(self, mock_print): + self.cli.api_client.request.return_value = [ + {'timestamp': '2026-01-01T00:00:00', 'alerts': ['disk full']} + ] + self.cli.do_health('') + self.assertTrue(any('disk full' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_do_health_error(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_health('') + mock_print.assert_any_call('❌ Failed to get health data') + + # ── do_backup / do_restore / do_backups ─────────────────────────────────── + + @patch('builtins.print') + def test_do_backup_success(self, mock_print): + self.cli.api_client.request.return_value = {'backup_id': 'bk123'} + self.cli.do_backup('') + mock_print.assert_any_call('✅ Backup created: bk123') + + @patch('builtins.print') + def test_do_backup_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_backup('') + mock_print.assert_any_call('❌ Failed to create backup') + + @patch('builtins.print') + def test_do_restore_no_arg(self, mock_print): + self.cli.do_restore('') + mock_print.assert_any_call('❌ Usage: restore ') + + @patch('builtins.print') + def test_do_restore_success(self, mock_print): + self.cli.api_client.request.return_value = {'ok': True} + self.cli.do_restore('bk123') + mock_print.assert_any_call('✅ Configuration restored from backup: bk123') + + @patch('builtins.print') + def test_do_restore_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_restore('bk123') + mock_print.assert_any_call('❌ Failed to restore configuration') + + @patch('builtins.print') + def test_do_backups_success(self, mock_print): + self.cli.api_client.request.return_value = [ + {'backup_id': 'bk1', 'timestamp': '2026-01-01', 'services': ['dns']} + ] + self.cli.do_backups('') + self.assertTrue(any('bk1' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_do_backups_failure(self, mock_print): + self.cli.api_client.request.return_value = None + self.cli.do_backups('') + mock_print.assert_any_call('❌ Failed to get backups') + + # ── do_service ──────────────────────────────────────────────────────────── + + @patch('builtins.print') + def test_do_service_no_arg(self, mock_print): + self.cli.do_service('') + mock_print.assert_any_call('❌ Usage: service ') + + @patch('builtins.print') + def test_do_service_sets_context(self, mock_print): + self.cli.do_service('email') + self.assertEqual(self.cli.current_service, 'email') + self.assertEqual(self.cli.prompt, 'picell:email> ') + + # ── do_exit / do_quit / do_EOF ─────────────────────────────────────────── + + @patch('builtins.print') + def test_do_exit_returns_true(self, mock_print): + result = self.cli.do_exit('') + self.assertTrue(result) + + @patch('builtins.print') + def test_do_quit_delegates_to_exit(self, mock_print): + result = self.cli.do_quit('') + self.assertTrue(result) + + @patch('builtins.print') + def test_do_eof_returns_true(self, mock_print): + result = self.cli.do_EOF('') + self.assertTrue(result) + + # ── show_status / list_services / show_config ───────────────────────────── + + @patch('builtins.print') + def test_show_status(self, mock_print): + self.cli.api_client.get = MagicMock(return_value={'cell_name': 'mycel', 'peers_count': 1}) + self.cli.show_status() + self.assertTrue(any('mycel' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_show_status_handles_none(self, mock_print): + self.cli.api_client.get = MagicMock(return_value=None) + self.cli.show_status() # Should not raise + + @patch('builtins.print') + def test_list_services(self, mock_print): + self.cli.api_client.get = MagicMock(return_value={'email': {'running': True}}) + self.cli.list_services() + mock_print.assert_called_once() + + @patch('builtins.print') + def test_show_config(self, mock_print): + self.cli.api_client.get = MagicMock(return_value={'cell_name': 'mycel'}) + self.cli.show_config() + self.assertTrue(any('cell_name' in str(c) for c in mock_print.call_args_list)) + + # ── batch_start_services / batch_stop_services ──────────────────────────── + + @patch('builtins.print') + def test_batch_start_services(self, mock_print): + self.cli.api_client.post = MagicMock(return_value={'ok': True}) + self.cli.batch_start_services(['email', 'dns']) + self.assertEqual(self.cli.api_client.post.call_count, 2) + + @patch('builtins.print') + def test_batch_stop_services(self, mock_print): + self.cli.api_client.post = MagicMock(return_value={'ok': True}) + self.cli.batch_stop_services(['email']) + self.assertEqual(self.cli.api_client.post.call_count, 1) + + +class TestDisplayMethods(unittest.TestCase): + def setUp(self): + self.cli = EnhancedCLI.__new__(EnhancedCLI) + + @patch('builtins.print') + def test_display_status_with_list_services(self, mock_print): + self.cli._display_status({'cell_name': 'mycel', 'services': ['dns', 'dhcp']}) + self.assertTrue(any('dns' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_display_status_with_dict_services(self, mock_print): + self.cli._display_status({ + 'cell_name': 'mycel', + 'services': { + 'email': {'running': True, 'status': 'online'}, + 'dns': False # non-dict service + } + }) + self.assertTrue(any('email' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_display_services_with_non_dict_status(self, mock_print): + self.cli._display_services({'email': True, 'timestamp': '2026-01-01'}) + self.assertTrue(any('email' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_display_peers(self, mock_print): + self.cli._display_peers([ + {'name': 'alice', 'ip': '10.0.0.2', 'public_key': 'abcdefghijklmnopqrst', 'added_at': '2026-01-01'} + ]) + self.assertTrue(any('alice' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_display_health_with_alerts(self, mock_print): + self.cli._display_health([ + {'timestamp': '2026-01-01T00:00:00', 'alerts': ['disk full']} + ]) + self.assertTrue(any('disk full' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_display_health_no_alerts(self, mock_print): + self.cli._display_health([ + {'timestamp': '2026-01-01T00:00:00', 'alerts': []} + ]) + self.assertTrue(any('2026-01-01' in str(c) for c in mock_print.call_args_list)) + + +class TestModuleLevelFunctions(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('builtins.print') + def test_batch_operations(self, mock_print): + """batch_operations executes commands via EnhancedCLI.onecmd.""" + # Mock requests to avoid actual HTTP calls + with patch('enhanced_cli.requests.get', side_effect=Exception('no server')): + batch_operations(['status', 'config']) + # Should have printed headers for both commands + self.assertTrue(mock_print.call_count >= 2) + + def test_export_config_json(self): + with patch('enhanced_cli.ConfigManager.__init__', lambda self, *a, **kw: setattr(self, 'config', {'key': 'val'}) or None): + with patch.object(ConfigManager, '_load_config', return_value={'key': 'val'}): + result = export_config('json') + self.assertIn('key', result) + + def test_import_config_success(self): + config_file = os.path.join(self.tmp, 'config.json') + with open(config_file, 'w') as f: + json.dump({'key': 'val'}, f) + # Use real ConfigManager with temp dir + with patch('enhanced_cli.ConfigManager') as MockCM: + mock_instance = MagicMock() + MockCM.return_value = mock_instance + result = import_config(config_file, 'json') + self.assertTrue(result) + mock_instance.import_config.assert_called_once() + + def test_import_config_nonexistent_file_returns_false(self): + result = import_config('/nonexistent/config.json', 'json') + self.assertFalse(result) + + +class TestConfigManagerJsonPath(unittest.TestCase): + """Cover ConfigManager branches that use .json suffix.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_init_with_json_path_sets_config_file_directly(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + self.assertEqual(str(cm.config_file), p) + self.assertEqual(cm.config_dir, cm.config_file.parent) + + def test_load_config_reads_json_file(self): + p = os.path.join(self.tmp, 'cfg.json') + with open(p, 'w') as f: + json.dump({'hello': 'world'}, f) + cm = ConfigManager(p) + self.assertEqual(cm.config.get('hello'), 'world') + + def test_load_config_exception_returns_empty(self): + p = os.path.join(self.tmp, 'cfg.json') + # Write invalid JSON + with open(p, 'w') as f: + f.write('not json {{') + cm = ConfigManager(p) + self.assertEqual(cm.config, {}) + + def test_save_config_writes_json(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + cm.config = {'saved': True} + cm.save() + with open(p) as f: + data = json.load(f) + self.assertTrue(data.get('saved')) + + def test_save_config_exception_does_not_raise(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + # Make the file unwritable by mocking open to raise + with patch('builtins.open', side_effect=OSError('disk full')): + cm.save() # must not raise + + def test_export_config_yaml_format(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + cm.config = {'key': 'value'} + result = cm.export_config('yaml') + self.assertIn('key', result) + + def test_export_config_unsupported_format_raises(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + with self.assertRaises(ValueError): + cm.export_config('xml') + + def test_import_config_yaml(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + cm.import_config('key: value\n', 'yaml') + self.assertEqual(cm.config.get('key'), 'value') + + def test_import_config_unsupported_format_prints_error(self): + p = os.path.join(self.tmp, 'cfg.json') + cm = ConfigManager(p) + # Should not raise even with unsupported format (caught internally) + with patch('builtins.print') as mock_print: + cm.import_config('...', 'xml') + self.assertTrue(any('Error' in str(c) for c in mock_print.call_args_list)) + + +class TestEnhancedCLIGetPost(unittest.TestCase): + """Cover EnhancedCLI.get() and .post() HTTP shortcut methods.""" + + def setUp(self): + self.cli = EnhancedCLI.__new__(EnhancedCLI) + self.cli.api_client = MagicMock() + self.cli.api_client.base_url = 'http://localhost:3000/api' + + @patch('enhanced_cli.requests.get') + def test_get_returns_json_on_success(self, mock_get): + mock_resp = MagicMock() + mock_resp.json.return_value = {'ok': True} + mock_resp.raise_for_status.return_value = None + mock_get.return_value = mock_resp + result = self.cli.get('/status') + self.assertEqual(result, {'ok': True}) + + @patch('enhanced_cli.requests.get') + def test_get_returns_none_on_exception(self, mock_get): + mock_get.side_effect = Exception('connection refused') + result = self.cli.get('/status') + self.assertIsNone(result) + + @patch('enhanced_cli.requests.post') + def test_post_returns_json_on_success(self, mock_post): + mock_resp = MagicMock() + mock_resp.json.return_value = {'created': True} + mock_resp.raise_for_status.return_value = None + mock_post.return_value = mock_resp + result = self.cli.post('/peers', {'name': 'alice'}) + self.assertEqual(result, {'created': True}) + + @patch('enhanced_cli.requests.post') + def test_post_returns_none_on_exception(self, mock_post): + mock_post.side_effect = Exception('timeout') + result = self.cli.post('/peers', {'name': 'alice'}) + self.assertIsNone(result) + + +class TestShowStatusExceptionPath(unittest.TestCase): + """Cover show_status() exception branch.""" + + def setUp(self): + self.cli = EnhancedCLI.__new__(EnhancedCLI) + self.cli.api_client = MagicMock() + + @patch('builtins.print') + def test_show_status_exception_prints_error(self, mock_print): + self.cli.api_client.get.side_effect = RuntimeError('api down') + self.cli.show_status() + self.assertTrue(any('Error' in str(c) for c in mock_print.call_args_list)) + + +class TestInteractiveMode(unittest.TestCase): + """Cover interactive_mode() loop.""" + + def setUp(self): + self.cli = EnhancedCLI.__new__(EnhancedCLI) + self.cli.api_client = MagicMock() + self.cli.config_manager = MagicMock() + self.cli.current_service = None + self.cli.prompt = 'picell> ' + + @patch('builtins.print') + def test_interactive_mode_exits_on_quit(self, mock_print): + with patch('builtins.input', side_effect=['quit']): + self.cli.interactive_mode() + mock_print.assert_any_call('Entering interactive mode. Type \'quit\' to exit.') + + @patch('builtins.print') + def test_interactive_mode_exits_on_eof(self, mock_print): + with patch('builtins.input', side_effect=EOFError): + self.cli.interactive_mode() + + +class TestMainFunction(unittest.TestCase): + """Cover main() argument branches.""" + + def _run_main(self, args): + import sys as _sys + old_argv = _sys.argv + _sys.argv = ['enhanced_cli'] + args + try: + from enhanced_cli import main + with patch('builtins.print'): + try: + main() + except SystemExit: + pass + finally: + _sys.argv = old_argv + + def test_main_no_args_prints_help(self): + with patch('enhanced_cli.argparse.ArgumentParser.print_help') as mock_help: + self._run_main([]) + mock_help.assert_called_once() + + def test_main_status_flag(self): + with patch('enhanced_cli.EnhancedCLI') as MockCLI: + mock_cli = MagicMock() + MockCLI.return_value = mock_cli + self._run_main(['--status']) + mock_cli.do_status.assert_called_once_with('') + + def test_main_services_flag(self): + with patch('enhanced_cli.EnhancedCLI') as MockCLI: + mock_cli = MagicMock() + MockCLI.return_value = mock_cli + self._run_main(['--services']) + mock_cli.do_services.assert_called_once_with('') + + def test_main_peers_flag(self): + with patch('enhanced_cli.EnhancedCLI') as MockCLI: + mock_cli = MagicMock() + MockCLI.return_value = mock_cli + self._run_main(['--peers']) + mock_cli.do_peers.assert_called_once_with('') + + def test_main_logs_flag(self): + with patch('enhanced_cli.EnhancedCLI') as MockCLI: + mock_cli = MagicMock() + MockCLI.return_value = mock_cli + self._run_main(['--logs', 'api']) + mock_cli.do_logs.assert_called_once_with('api') + + def test_main_health_flag(self): + with patch('enhanced_cli.EnhancedCLI') as MockCLI: + mock_cli = MagicMock() + MockCLI.return_value = mock_cli + self._run_main(['--health']) + mock_cli.do_health.assert_called_once_with('') + + def test_main_batch_flag(self): + with patch('enhanced_cli.batch_operations') as mock_batch: + self._run_main(['--batch', 'status', 'config']) + mock_batch.assert_called_once_with(['status', 'config']) + + def test_main_export_config_flag(self): + with patch('enhanced_cli.export_config', return_value='{}') as mock_export: + self._run_main(['--export-config', 'json']) + mock_export.assert_called_once_with('json') + + def test_main_import_config_json_file(self): + with patch('enhanced_cli.import_config', return_value=True) as mock_import: + self._run_main(['--import-config', 'config.json']) + mock_import.assert_called_once_with('config.json', 'json') + + def test_main_import_config_yaml_file(self): + with patch('enhanced_cli.import_config', return_value=True) as mock_import: + self._run_main(['--import-config', 'config.yaml']) + mock_import.assert_called_once_with('config.yaml', 'yaml') + + def test_main_wizard_flag(self): + with patch('enhanced_cli.service_wizard') as mock_wizard: + self._run_main(['--wizard', 'email']) + mock_wizard.assert_called_once_with('email') + + +class TestServiceWizardFunction(unittest.TestCase): + """Cover service_wizard() branches.""" + + def _call_wizard(self, service, inputs): + from enhanced_cli import service_wizard + with patch('builtins.input', side_effect=inputs): + with patch('builtins.print'): + with patch('enhanced_cli.APIClient') as MockAPI: + mock_client = MagicMock() + mock_client.request.return_value = {'ok': True} + MockAPI.return_value = mock_client + service_wizard(service) + return mock_client + + def test_service_wizard_network_calls_api(self): + client = self._call_wizard('network', ['53', '10.0.0.100-200', '', '']) + client.request.assert_called_once() + + def test_service_wizard_wireguard_calls_api(self): + client = self._call_wizard('wireguard', ['51820', '10.0.0.1/24']) + client.request.assert_called_once() + + def test_service_wizard_email_calls_api(self): + client = self._call_wizard('email', ['example.com', '587', '993']) + client.request.assert_called_once() + + @patch('builtins.print') + def test_service_wizard_unknown_service_prints_error(self, mock_print): + from enhanced_cli import service_wizard + service_wizard('unknown_service') + self.assertTrue(any('Wizard not available' in str(c) for c in mock_print.call_args_list)) + + @patch('builtins.print') + def test_service_wizard_api_failure_prints_error(self, mock_print): + from enhanced_cli import service_wizard + with patch('builtins.input', side_effect=['53', '10.0.0.100-200', '', '']): + with patch('enhanced_cli.APIClient') as MockAPI: + mock_client = MagicMock() + mock_client.request.return_value = None + MockAPI.return_value = mock_client + service_wizard('network') + self.assertTrue(any('Failed' in str(c) for c in mock_print.call_args_list)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_file_manager_extra.py b/tests/test_file_manager_extra.py new file mode 100644 index 0000000..d603cd0 --- /dev/null +++ b/tests/test_file_manager_extra.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +""" +Additional tests for FileManager covering uncovered paths: +- _safe_path traversal rejection +- create_user: duplicate (file already has user), invalid username +- delete_user: not in file, invalid username +- list_users with actual htpasswd file +- get_users from users.json +- _get_user_storage_info for nonexistent dir +- _list_user_folders +- create_folder / delete_folder edge cases +- upload_file / download_file / delete_file edge cases +- list_files edge cases +- backup_user_files invalid username +- restore_user_files invalid username / nonexistent backup +- get_status in docker and non-docker mode +- _test_filesystem_access +- _test_user_authentication +- test_connectivity +- get_webdav_status (mocked subprocess) +""" + +import sys +import os +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from file_manager import FileManager + + +class TestFileManagerExtra(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.data_dir = os.path.join(self.tmp, 'data') + self.config_dir = os.path.join(self.tmp, 'config') + os.makedirs(self.data_dir, exist_ok=True) + os.makedirs(self.config_dir, exist_ok=True) + self.fm = FileManager(data_dir=self.data_dir, config_dir=self.config_dir) + + def tearDown(self): + shutil.rmtree(self.tmp) + + # ── _safe_path ────────────────────────────────────────────────────────── + + def test_safe_path_traversal_raises(self): + with self.assertRaises(ValueError): + self.fm._safe_path('alice', '../../../etc/passwd') + + def test_safe_path_invalid_username_raises(self): + with self.assertRaises(ValueError): + self.fm._safe_path('user with space', 'file.txt') + + def test_safe_path_valid_returns_path_under_files_dir(self): + path = self.fm._safe_path('alice', 'Documents', 'readme.txt') + self.assertTrue(path.startswith(self.fm.files_dir)) + + # ── create_user ───────────────────────────────────────────────────────── + + def test_create_user_invalid_username_chars_returns_false(self): + result = self.fm.create_user('bad user!', 'password') + self.assertFalse(result) + + def test_create_user_creates_default_folders(self): + result = self.fm.create_user('alice', 'secret') + self.assertTrue(result) + user_dir = os.path.join(self.fm.files_dir, 'alice') + for folder in ('Documents', 'Pictures', 'Music', 'Videos', 'Downloads'): + self.assertTrue(os.path.isdir(os.path.join(user_dir, folder))) + + def test_create_user_bcrypt_hash_uses_2y_prefix(self): + self.fm.create_user('alice', 'secret') + auth_file = os.path.join(self.fm.webdav_dir, 'users') + with open(auth_file) as f: + line = f.read() + self.assertIn('$2y$', line) + self.assertNotIn('$2b$', line) + + def test_create_user_appends_to_existing_file(self): + self.fm.create_user('alice', 'secret') + self.fm.create_user('bob', 'secret2') + auth_file = os.path.join(self.fm.webdav_dir, 'users') + with open(auth_file) as f: + lines = [l for l in f if l.strip()] + self.assertEqual(len(lines), 2) + + # ── delete_user ───────────────────────────────────────────────────────── + + def test_delete_user_invalid_username_returns_false(self): + result = self.fm.delete_user('bad user!') + self.assertFalse(result) + + def test_delete_user_not_in_auth_file_still_returns_true(self): + """Delete succeeds even if user was never in the auth file.""" + result = self.fm.delete_user('nobody') + self.assertTrue(result) + + def test_delete_user_removes_only_matching_line(self): + self.fm.create_user('alice', 'secret') + self.fm.create_user('bob', 'secret2') + self.fm.delete_user('alice') + auth_file = os.path.join(self.fm.webdav_dir, 'users') + with open(auth_file) as f: + content = f.read() + self.assertNotIn('alice:', content) + self.assertIn('bob:', content) + + # ── list_users ────────────────────────────────────────────────────────── + + def test_list_users_returns_usernames_from_auth_file(self): + self.fm.create_user('alice', 'secret') + self.fm.create_user('bob', 'secret2') + users = self.fm.list_users() + usernames = [u['username'] for u in users] + self.assertIn('alice', usernames) + self.assertIn('bob', usernames) + + def test_list_users_skips_malformed_lines(self): + auth_file = os.path.join(self.fm.webdav_dir, 'users') + with open(auth_file, 'w') as f: + f.write('alicehash\n') # no colon + f.write('bob:hash\n') + users = self.fm.list_users() + usernames = [u['username'] for u in users] + self.assertNotIn('alicehash', usernames) + self.assertIn('bob', usernames) + + # ── get_users ─────────────────────────────────────────────────────────── + + def test_get_users_returns_empty_when_no_file(self): + result = self.fm.get_users() + self.assertEqual(result, []) + + def test_get_users_returns_list_from_json(self): + import json + webdav_dir = os.path.join(self.config_dir, 'webdav') + os.makedirs(webdav_dir, exist_ok=True) + users_file = os.path.join(webdav_dir, 'users.json') + with open(users_file, 'w') as f: + json.dump([{'username': 'alice'}], f) + result = self.fm.get_users() + self.assertEqual(len(result), 1) + self.assertEqual(result[0]['username'], 'alice') + + # ── _get_user_storage_info ─────────────────────────────────────────────── + + def test_storage_info_nonexistent_dir_returns_zeros(self): + info = self.fm._get_user_storage_info('nobody') + self.assertEqual(info['total_files'], 0) + self.assertEqual(info['total_size_bytes'], 0) + + def test_storage_info_counts_files(self): + self.fm.create_user('alice', 'secret') + self.fm.upload_file('alice', 'Documents/a.txt', b'hello') + self.fm.upload_file('alice', 'Documents/b.txt', b'world') + info = self.fm._get_user_storage_info('alice') + self.assertGreaterEqual(info['total_files'], 2) + + # ── _list_user_folders ─────────────────────────────────────────────────── + + def test_list_user_folders_returns_list(self): + self.fm.create_user('alice', 'secret') + folders = self.fm._list_user_folders('alice') + self.assertIsInstance(folders, list) + names = [f['name'] for f in folders] + self.assertIn('Documents', names) + + def test_list_user_folders_empty_for_nonexistent_user(self): + folders = self.fm._list_user_folders('nobody') + self.assertEqual(folders, []) + + # ── create_folder / delete_folder edge cases ──────────────────────────── + + def test_create_folder_traversal_returns_false(self): + self.fm.create_user('alice', 'secret') + result = self.fm.create_folder('alice', '../../../tmp/evil') + self.assertFalse(result) + + def test_delete_folder_nonexistent_returns_false(self): + self.fm.create_user('alice', 'secret') + result = self.fm.delete_folder('alice', 'NoSuchFolder') + self.assertFalse(result) + + def test_delete_folder_traversal_returns_false(self): + self.fm.create_user('alice', 'secret') + result = self.fm.delete_folder('alice', '../../../tmp/evil') + self.assertFalse(result) + + # ── upload / download / delete edge cases ─────────────────────────────── + + def test_download_file_nonexistent_returns_none(self): + self.fm.create_user('alice', 'secret') + result = self.fm.download_file('alice', 'nope.txt') + self.assertIsNone(result) + + def test_delete_file_nonexistent_returns_false(self): + self.fm.create_user('alice', 'secret') + result = self.fm.delete_file('alice', 'nope.txt') + self.assertFalse(result) + + def test_upload_file_creates_parent_dirs(self): + self.fm.create_user('alice', 'secret') + result = self.fm.upload_file('alice', 'deep/nested/file.txt', b'data') + self.assertTrue(result) + path = self.fm._safe_path('alice', 'deep/nested/file.txt') + self.assertTrue(os.path.exists(path)) + + # ── list_files ────────────────────────────────────────────────────────── + + def test_list_files_shows_files_and_dirs(self): + self.fm.create_user('alice', 'secret') + self.fm.upload_file('alice', 'docs/readme.txt', b'hello') + files = self.fm.list_files('alice', 'docs') + self.assertEqual(len(files), 1) + self.assertEqual(files[0]['type'], 'file') + + def test_list_files_empty_folder(self): + self.fm.create_user('alice', 'secret') + files = self.fm.list_files('alice', 'Videos') + self.assertIsInstance(files, list) + self.assertEqual(len(files), 0) + + # ── backup / restore ──────────────────────────────────────────────────── + + def test_backup_user_files_invalid_username_returns_false(self): + result = self.fm.backup_user_files('../../etc', '/tmp/backup') + self.assertFalse(result) + + def test_backup_user_files_empty_username_returns_false(self): + result = self.fm.backup_user_files('', '/tmp/backup') + self.assertFalse(result) + + def test_restore_user_files_invalid_username_returns_false(self): + result = self.fm.restore_user_files('../../etc', '/tmp/backup') + self.assertFalse(result) + + def test_restore_user_files_nonexistent_backup_returns_false(self): + result = self.fm.restore_user_files('alice', '/tmp/nonexistent_backup') + self.assertFalse(result) + + # ── _test_filesystem_access ───────────────────────────────────────────── + + def test_test_filesystem_access_succeeds(self): + result = self.fm._test_filesystem_access() + self.assertTrue(result['success']) + self.assertTrue(result['read_write']) + + # ── _test_user_authentication ─────────────────────────────────────────── + + def test_test_user_authentication_no_file_returns_success(self): + result = self.fm._test_user_authentication() + self.assertTrue(result['success']) + self.assertEqual(result['users_count'], 0) + + def test_test_user_authentication_counts_users(self): + self.fm.create_user('alice', 'secret') + self.fm.create_user('bob', 'secret2') + result = self.fm._test_user_authentication() + self.assertTrue(result['success']) + self.assertEqual(result['users_count'], 2) + + # ── test_connectivity ─────────────────────────────────────────────────── + + @patch('requests.get') + @patch('requests.options') + def test_test_connectivity_returns_dict_with_expected_keys(self, mock_options, mock_get): + mock_get.side_effect = Exception('connection refused') + mock_options.side_effect = Exception('connection refused') + result = self.fm.test_connectivity() + self.assertIn('webdav_connectivity', result) + self.assertIn('filesystem_access', result) + self.assertIn('user_authentication', result) + self.assertIn('success', result) + + # ── get_status ────────────────────────────────────────────────────────── + + @patch.dict(os.environ, {'DOCKER_CONTAINER': 'true'}) + def test_get_status_docker_mode_returns_dict(self): + result = self.fm.get_status() + self.assertIn('running', result) + self.assertIn('status', result) + + @patch.dict(os.environ, {'DOCKER_CONTAINER': 'false'}) + @patch('subprocess.run') + @patch('requests.get') + @patch('requests.options') + def test_get_status_non_docker_mode(self, mock_opt, mock_get, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + mock_get.side_effect = Exception('refused') + mock_opt.side_effect = Exception('refused') + result = self.fm.get_status() + self.assertIn('running', result) + + # ── _get_total_storage_used ───────────────────────────────────────────── + + def test_get_total_storage_used_empty_dir(self): + result = self.fm._get_total_storage_used() + self.assertEqual(result['total_files'], 0) + self.assertEqual(result['total_size_bytes'], 0) + + def test_get_total_storage_used_with_files(self): + self.fm.create_user('alice', 'secret') + self.fm.upload_file('alice', 'Documents/test.txt', b'hello world') + result = self.fm._get_total_storage_used() + self.assertGreater(result['total_files'], 0) + self.assertGreater(result['total_size_bytes'], 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_firewall_manager_extra.py b/tests/test_firewall_manager_extra.py new file mode 100644 index 0000000..f37cb93 --- /dev/null +++ b/tests/test_firewall_manager_extra.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Additional tests for firewall_manager.py covering missed lines: +- _run() exception path (lines 52-54) +- ensure_caddy_virtual_ips() add-failure branch and exception path +- _rule_exists, _ensure_rule, _delete_rule +- reload_coredns (success, failure, exception) +- apply_all_dns_rules +- _service_tag +- apply_service_rules +- clear_service_rules +""" + +import sys +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock, call +import subprocess + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +import firewall_manager + + +def _make_proc(returncode=0, stdout='', stderr=''): + p = MagicMock() + p.returncode = returncode + p.stdout = stdout + p.stderr = stderr + return p + + +class TestRunFunction(unittest.TestCase): + def test_run_success_returns_result(self): + proc = _make_proc(returncode=0, stdout='ok') + with patch('subprocess.run', return_value=proc): + result = firewall_manager._run(['echo', 'ok']) + self.assertEqual(result.returncode, 0) + + def test_run_nonzero_with_check_logs_warning(self): + proc = _make_proc(returncode=1, stderr='error') + with patch('subprocess.run', return_value=proc): + result = firewall_manager._run(['false'], check=True) + self.assertEqual(result.returncode, 1) + + def test_run_exception_reraises(self): + with patch('subprocess.run', side_effect=subprocess.TimeoutExpired(['cmd'], 1)): + with self.assertRaises(subprocess.TimeoutExpired): + firewall_manager._run(['cmd']) + + +class TestEnsureCaddyVirtualIps(unittest.TestCase): + def test_exception_returns_false(self): + with patch.object(firewall_manager, '_caddy_exec', side_effect=RuntimeError('no docker')): + result = firewall_manager.ensure_caddy_virtual_ips() + self.assertFalse(result) + + def test_ip_already_present_skips_add(self): + # All IPs are already in the existing output + all_ips = ' '.join(firewall_manager.SERVICE_IPS.values()) + mock_result = _make_proc(returncode=0, stdout=all_ips) + with patch.object(firewall_manager, '_caddy_exec', return_value=mock_result) as mock_exec: + result = firewall_manager.ensure_caddy_virtual_ips() + self.assertTrue(result) + # ip addr show was called once; no add calls + self.assertEqual(mock_exec.call_count, 1) + + def test_missing_ip_triggers_add(self): + # No IPs in stdout → all IPs need to be added + calls_made = [] + + def fake_caddy_exec(args): + calls_made.append(args) + return _make_proc(returncode=0, stdout='') + + with patch.object(firewall_manager, '_caddy_exec', side_effect=fake_caddy_exec): + result = firewall_manager.ensure_caddy_virtual_ips() + self.assertTrue(result) + # First call is ip addr show; subsequent calls are ip addr add + self.assertGreater(len(calls_made), 1) + + def test_add_failure_logs_warning(self): + # First call (ip addr show) returns empty; subsequent calls (ip addr add) fail + call_count = [0] + + def fake_caddy_exec(args): + call_count[0] += 1 + if call_count[0] == 1: + return _make_proc(returncode=0, stdout='') + return _make_proc(returncode=1, stderr='failed to add IP') + + with patch.object(firewall_manager, '_caddy_exec', side_effect=fake_caddy_exec): + result = firewall_manager.ensure_caddy_virtual_ips() + self.assertTrue(result) # Function still returns True even on add failure + + +class TestRuleHelpers(unittest.TestCase): + def test_rule_exists_returns_true_when_returncode_0(self): + with patch.object(firewall_manager, '_iptables', return_value=_make_proc(returncode=0)): + result = firewall_manager._rule_exists('FORWARD', ['-j', 'ACCEPT']) + self.assertTrue(result) + + def test_rule_exists_returns_false_when_nonzero(self): + with patch.object(firewall_manager, '_iptables', return_value=_make_proc(returncode=1)): + result = firewall_manager._rule_exists('FORWARD', ['-j', 'ACCEPT']) + self.assertFalse(result) + + def test_ensure_rule_inserts_when_not_present(self): + calls = [] + def fake_iptables(args, check=False): + calls.append(args[0]) + if args[0] == '-C': + return _make_proc(returncode=1) + return _make_proc(returncode=0) + + with patch.object(firewall_manager, '_iptables', side_effect=fake_iptables): + firewall_manager._ensure_rule('FORWARD', ['-j', 'ACCEPT']) + self.assertIn('-I', calls) + + def test_ensure_rule_skips_insert_when_already_present(self): + with patch.object(firewall_manager, '_iptables', return_value=_make_proc(returncode=0)) as mock_ipt: + firewall_manager._ensure_rule('FORWARD', ['-j', 'ACCEPT']) + # Only the -C check call was made + self.assertEqual(mock_ipt.call_count, 1) + + def test_delete_rule_calls_delete_while_exists(self): + check_count = [0] + def fake_iptables(args, check=False): + if args[0] == '-C': + check_count[0] += 1 + # Rule exists on first check, gone after first delete + return _make_proc(returncode=0 if check_count[0] == 1 else 1) + # -D delete call: return success + return _make_proc(returncode=0) + + with patch.object(firewall_manager, '_iptables', side_effect=fake_iptables): + firewall_manager._delete_rule('FORWARD', ['-j', 'ACCEPT']) + # Should have checked twice (once found, once not found) and deleted once + self.assertEqual(check_count[0], 2) + + +class TestReloadCoreDns(unittest.TestCase): + def test_success_returns_true(self): + with patch.object(firewall_manager, '_run', return_value=_make_proc(returncode=0)): + result = firewall_manager.reload_coredns() + self.assertTrue(result) + + def test_nonzero_returncode_returns_false(self): + with patch.object(firewall_manager, '_run', return_value=_make_proc(returncode=1, stderr='not found')): + result = firewall_manager.reload_coredns() + self.assertFalse(result) + + def test_exception_returns_false(self): + with patch.object(firewall_manager, '_run', side_effect=RuntimeError('no docker')): + result = firewall_manager.reload_coredns() + self.assertFalse(result) + + +class TestApplyAllDnsRules(unittest.TestCase): + def test_generates_corefile_and_calls_reload_on_success(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=True) as mock_gen, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True) as mock_reload: + result = firewall_manager.apply_all_dns_rules([], '/tmp/Corefile') + self.assertTrue(result) + mock_gen.assert_called_once() + mock_reload.assert_called_once() + + def test_does_not_call_reload_when_generate_fails(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=False), \ + patch.object(firewall_manager, 'reload_coredns') as mock_reload: + result = firewall_manager.apply_all_dns_rules([], '/tmp/Corefile') + self.assertFalse(result) + mock_reload.assert_not_called() + + +class TestServiceTag(unittest.TestCase): + def test_lowercase_and_replace_special_chars(self): + tag = firewall_manager._service_tag('my-service_v2!') + self.assertEqual(tag, 'pic-svc-my-service-v2-') + + def test_simple_id(self): + tag = firewall_manager._service_tag('gitea') + self.assertEqual(tag, 'pic-svc-gitea') + + +class TestApplyServiceRules(unittest.TestCase): + def test_applies_accept_rules_via_iptables(self): + calls = [] + def fake_iptables(args, check=False): + calls.append(args) + return _make_proc(returncode=0) + + rules = [{'type': 'ACCEPT', 'dest_ip': '10.20.0.5', 'dest_port': 80, 'proto': 'tcp'}] + with patch.object(firewall_manager, '_iptables', side_effect=fake_iptables), \ + patch.object(firewall_manager, 'clear_service_rules'): + result = firewall_manager.apply_service_rules('gitea', '10.20.0.5', rules) + self.assertTrue(result) + self.assertTrue(any('FORWARD' in str(c) for c in calls)) + + def test_skips_non_accept_rules(self): + calls = [] + rules = [{'type': 'DROP', 'dest_ip': '10.20.0.5', 'dest_port': 80, 'proto': 'tcp'}] + with patch.object(firewall_manager, '_iptables', side_effect=lambda *a, **kw: calls.append(a) or _make_proc()), \ + patch.object(firewall_manager, 'clear_service_rules'): + result = firewall_manager.apply_service_rules('gitea', '10.20.0.5', rules) + self.assertTrue(result) + self.assertEqual(len(calls), 0) + + def test_service_ip_placeholder_substituted(self): + captured = [] + def fake_iptables(args, check=False): + captured.extend(args) + return _make_proc(returncode=0) + + rules = [{'type': 'ACCEPT', 'dest_ip': '${SERVICE_IP}', 'dest_port': 8080, 'proto': 'tcp'}] + with patch.object(firewall_manager, '_iptables', side_effect=fake_iptables), \ + patch.object(firewall_manager, 'clear_service_rules'): + firewall_manager.apply_service_rules('app', '10.20.0.9', rules) + self.assertIn('10.20.0.9', captured) + + +class TestClearServiceRules(unittest.TestCase): + def test_no_matching_rules_skips_restore(self): + # iptables-save returns output with no matching tag + save_proc = _make_proc(returncode=0, stdout='*filter\n-A FORWARD -j ACCEPT\nCOMMIT\n') + with patch.object(firewall_manager, '_wg_exec', return_value=save_proc), \ + patch('subprocess.run') as mock_restore: + firewall_manager.clear_service_rules('nonexistent-svc') + mock_restore.assert_not_called() + + def test_exception_is_logged_not_raised(self): + with patch.object(firewall_manager, '_wg_exec', side_effect=RuntimeError('no docker')): + # Should not raise + firewall_manager.clear_service_rules('gitea') + + def test_save_failure_skips_restore(self): + save_proc = _make_proc(returncode=1, stderr='failed') + with patch.object(firewall_manager, '_wg_exec', return_value=save_proc), \ + patch('subprocess.run') as mock_restore: + firewall_manager.clear_service_rules('gitea') + mock_restore.assert_not_called() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_is_local_request_per_endpoint.py b/tests/test_is_local_request_per_endpoint.py index b8b0934..6e0a6cb 100644 --- a/tests/test_is_local_request_per_endpoint.py +++ b/tests/test_is_local_request_per_endpoint.py @@ -33,7 +33,6 @@ Tested local-only endpoints (representative sample): Tested public endpoints (no is_local_request guard): GET /api/calendar/status GET /api/dns/records - GET /api/dhcp/leases GET /api/cells """ @@ -216,12 +215,6 @@ class TestPublicEndpointsNotBlockedForNonLocal(unittest.TestCase): r = _get_non_local(self.client, '/api/dns/records') self.assertNotEqual(r.status_code, 403) - @patch('app.network_manager') - def test_dhcp_leases_not_403_for_non_local(self, mock_nm): - mock_nm.get_dhcp_leases.return_value = [] - r = _get_non_local(self.client, '/api/dhcp/leases') - self.assertNotEqual(r.status_code, 403) - @patch('app.cell_link_manager') def test_cells_list_not_403_for_non_local(self, mock_clm): mock_clm.list_connections.return_value = [] diff --git a/tests/test_log_manager_extra.py b/tests/test_log_manager_extra.py new file mode 100644 index 0000000..ee85333 --- /dev/null +++ b/tests/test_log_manager_extra.py @@ -0,0 +1,438 @@ +#!/usr/bin/env python3 +""" +Additional tests for LogManager covering uncovered paths: +- get_service_logs_parsed (JSON + non-JSON lines, level filtering) +- search_logs (time_range filter, level filter, non-JSON lines) +- _matches_search_criteria (query/time_range/level checks) +- _is_log_level (JSON and text fallback) +- export_logs (json/csv/text/unknown format) +- _logs_to_csv / _logs_to_text +- get_log_statistics (single service and all services) +- get_log_file_info (missing file → error key) +- set_service_level (known and unknown service) +- get_service_levels +- get_all_log_file_infos (active + rotated files) +- compress_old_logs +- stop +- _start_rotation_monitor creates a running thread +""" + +import sys +import os +import json +import gzip +import shutil +import tempfile +import time +import unittest +from datetime import datetime, timedelta +from pathlib import Path + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from log_manager import LogManager + + +def _lm(tmp): + log_dir = os.path.join(tmp, 'logs') + os.makedirs(log_dir, exist_ok=True) + return LogManager(log_dir=log_dir) + + +def _write_log_file(log_dir, service, entries): + """Write JSON log entries to a service log file.""" + path = os.path.join(log_dir, f'{service}.log') + with open(path, 'w') as f: + for e in entries: + f.write(json.dumps(e) + '\n') + return path + + +class TestGetServiceLogsParsed(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_missing_log_file_returns_error_entry(self): + result = self.lm.get_service_logs_parsed('nosuchservice') + self.assertEqual(len(result), 1) + self.assertIn('error', result[0]) + + def test_returns_parsed_json_entries(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'ok'}, + {'timestamp': '2026-01-01T00:00:01', 'level': 'ERROR', 'message': 'fail'}, + ]) + result = self.lm.get_service_logs_parsed('svc', level='ALL', lines=10) + self.assertEqual(len(result), 2) + levels = {e['level'] for e in result} + self.assertIn('INFO', levels) + self.assertIn('ERROR', levels) + + def test_level_filter_excludes_non_matching(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'ok'}, + {'timestamp': '2026-01-01T00:00:01', 'level': 'ERROR', 'message': 'fail'}, + ]) + result = self.lm.get_service_logs_parsed('svc', level='ERROR', lines=10) + self.assertTrue(all(e.get('level') == 'ERROR' for e in result)) + + def test_non_json_lines_are_included_for_all_level(self): + path = os.path.join(self.lm.log_dir, 'svc.log') + with open(path, 'w') as f: + f.write('plain text log line\n') + result = self.lm.get_service_logs_parsed('svc', level='ALL', lines=10) + self.assertEqual(len(result), 1) + self.assertIn('raw_line', result[0]) + + +class TestSearchLogs(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_search_finds_matching_message(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'user login ok'}, + {'timestamp': '2026-01-01T00:00:01', 'level': 'INFO', 'message': 'disk full'}, + ]) + self.lm.service_loggers['svc'] = type('', (), {})() # register so search includes it + results = self.lm.search_logs('login', services=['svc']) + self.assertEqual(len(results), 1) + self.assertIn('login', results[0]['message']) + + def test_search_level_filter(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'info msg'}, + {'timestamp': '2026-01-01T00:00:01', 'level': 'ERROR', 'message': 'error msg'}, + ]) + results = self.lm.search_logs('', services=['svc'], level='ERROR') + self.assertTrue(all(e.get('level') == 'ERROR' for e in results)) + + def test_search_time_range_filter(self): + early = '2026-01-01T10:00:00' + late = '2026-01-01T14:00:00' + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': early, 'level': 'INFO', 'message': 'early'}, + {'timestamp': late, 'level': 'INFO', 'message': 'late'}, + ]) + # Register logger so search_logs finds the service + self.lm.add_service_logger('svc', {'level': 'INFO'}) + start = datetime(2026, 1, 1, 11, 0, 0) + end = datetime(2026, 1, 1, 13, 0, 0) + results = self.lm.search_logs('', services=['svc'], time_range=(start, end)) + msgs = [r.get('message', '') for r in results] + # 'early' is within 11:00-13:00 range, 'late' at 14:00 should be excluded + self.assertNotIn('late', msgs) + + def test_non_json_lines_matched_by_query(self): + path = os.path.join(self.lm.log_dir, 'svc.log') + with open(path, 'w') as f: + f.write('this line contains keyterm\n') + f.write('unrelated line\n') + results = self.lm.search_logs('keyterm', services=['svc']) + self.assertEqual(len(results), 1) + self.assertIn('keyterm', results[0]['raw_line']) + + def test_search_no_services_returns_empty(self): + results = self.lm.search_logs('anything', services=[]) + self.assertEqual(results, []) + + +class TestMatchesSearchCriteria(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_query_match(self): + entry = {'message': 'hello world', 'level': 'INFO'} + self.assertTrue(self.lm._matches_search_criteria(entry, 'hello', None, None)) + + def test_query_no_match(self): + entry = {'message': 'hello world', 'level': 'INFO'} + self.assertFalse(self.lm._matches_search_criteria(entry, 'nothere', None, None)) + + def test_level_filter_match(self): + entry = {'message': 'msg', 'level': 'ERROR'} + self.assertTrue(self.lm._matches_search_criteria(entry, '', None, 'ERROR')) + + def test_level_filter_no_match(self): + entry = {'message': 'msg', 'level': 'INFO'} + self.assertFalse(self.lm._matches_search_criteria(entry, '', None, 'ERROR')) + + def test_time_range_within(self): + entry = {'message': 'msg', 'level': 'INFO', 'timestamp': '2026-06-01T12:00:00'} + start = datetime(2026, 6, 1, 11, 0) + end = datetime(2026, 6, 1, 13, 0) + self.assertTrue(self.lm._matches_search_criteria(entry, '', (start, end), None)) + + def test_time_range_outside(self): + entry = {'message': 'msg', 'level': 'INFO', 'timestamp': '2026-06-01T14:00:00'} + start = datetime(2026, 6, 1, 11, 0) + end = datetime(2026, 6, 1, 13, 0) + self.assertFalse(self.lm._matches_search_criteria(entry, '', (start, end), None)) + + def test_time_range_invalid_timestamp_excluded(self): + entry = {'message': 'msg', 'level': 'INFO', 'timestamp': 'not-a-date'} + start = datetime(2026, 6, 1, 11, 0) + end = datetime(2026, 6, 1, 13, 0) + self.assertFalse(self.lm._matches_search_criteria(entry, '', (start, end), None)) + + +class TestIsLogLevel(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_json_line_matches_level(self): + line = json.dumps({'level': 'ERROR', 'message': 'oops'}) + self.assertTrue(self.lm._is_log_level(line, 'ERROR')) + + def test_json_line_no_match(self): + line = json.dumps({'level': 'INFO', 'message': 'ok'}) + self.assertFalse(self.lm._is_log_level(line, 'ERROR')) + + def test_text_line_fallback_match(self): + line = '2026-01-01 ERROR something went wrong' + self.assertTrue(self.lm._is_log_level(line, 'ERROR')) + + def test_text_line_fallback_no_match(self): + line = '2026-01-01 INFO something went fine' + self.assertFalse(self.lm._is_log_level(line, 'ERROR')) + + +class TestExportLogs(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_export_json_format(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'hi'}, + ]) + result = self.lm.export_logs('json', filters={'services': ['svc']}) + parsed = json.loads(result) + self.assertIsInstance(parsed, list) + + def test_export_csv_format(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'hi'}, + ]) + result = self.lm.export_logs('csv', filters={'services': ['svc']}) + self.assertIsInstance(result, str) + + def test_export_text_format(self): + _write_log_file(self.lm.log_dir, 'svc', [ + {'timestamp': '2026-01-01T00:00:00', 'level': 'INFO', 'message': 'hi'}, + ]) + result = self.lm.export_logs('text', filters={'services': ['svc']}) + self.assertIsInstance(result, str) + + def test_export_unknown_format_raises(self): + with self.assertRaises(ValueError): + self.lm.export_logs('pdf') + + def test_logs_to_csv_empty_returns_empty_string(self): + result = self.lm._logs_to_csv([]) + self.assertEqual(result, '') + + def test_logs_to_csv_has_header(self): + logs = [{'level': 'INFO', 'message': 'hi', 'timestamp': '2026-01-01T00:00:00'}] + csv = self.lm._logs_to_csv(logs) + lines = csv.split('\n') + self.assertGreater(len(lines), 1) + header_fields = set(lines[0].split(',')) + self.assertIn('level', header_fields) + self.assertIn('message', header_fields) + + def test_logs_to_text_formats_entries(self): + logs = [{'level': 'ERROR', 'service': 'svc', 'message': 'oops', 'timestamp': '2026-01-01T00:00:00'}] + text = self.lm._logs_to_text(logs) + self.assertIn('ERROR', text) + self.assertIn('oops', text) + + +class TestGetLogStatistics(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_stats_for_missing_service_log(self): + self.lm.add_service_logger('svc', {'level': 'INFO'}) + stats = self.lm.get_log_statistics('svc') + # Log file may not have been written yet; should still return dict + self.assertIsInstance(stats, dict) + + def test_stats_counts_levels(self): + self.lm.add_service_logger('svc', {'level': 'INFO'}) + lgr = self.lm.service_loggers['svc'] + lgr.info('info msg') + lgr.error('error msg') + # Flush handlers + for h in lgr.handlers: + h.flush() + stats = self.lm.get_log_statistics('svc') + self.assertIn('svc', stats) + + def test_stats_for_nonexistent_service_returns_error_key(self): + # A service that was never added has no log file + stats = self.lm.get_log_statistics('nosuch') + self.assertIn('nosuch', stats) + self.assertIn('error', stats['nosuch']) + + +class TestGetLogFileInfo(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_returns_error_for_missing_log(self): + info = self.lm.get_log_file_info('nosuchservice') + self.assertIn('error', info) + + def test_returns_file_info_for_existing_log(self): + self.lm.add_service_logger('svc', {'level': 'INFO'}) + lgr = self.lm.service_loggers['svc'] + lgr.info('test') + for h in lgr.handlers: + h.flush() + info = self.lm.get_log_file_info('svc') + self.assertIn('file_path', info) + self.assertIn('exists', info) + self.assertTrue(info['exists']) + + +class TestSetServiceLevel(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_set_level_for_known_service(self): + import logging + self.lm.add_service_logger('svc', {'level': 'INFO'}) + self.lm.set_service_level('svc', 'DEBUG') + lgr = self.lm.service_loggers['svc'] + self.assertEqual(lgr.level, logging.DEBUG) + + def test_set_level_for_unknown_service_does_not_raise(self): + self.lm.set_service_level('nosuch', 'DEBUG') # must not raise + + def test_get_service_levels_returns_dict(self): + self.lm.add_service_logger('svc', {'level': 'INFO'}) + levels = self.lm.get_service_levels() + self.assertIsInstance(levels, dict) + self.assertIn('svc', levels) + + +class TestGetAllLogFileInfos(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_returns_list(self): + result = self.lm.get_all_log_file_infos() + self.assertIsInstance(result, list) + + def test_includes_active_log_file(self): + self.lm.add_service_logger('svc', {'level': 'INFO'}) + lgr = self.lm.service_loggers['svc'] + lgr.info('test') + for h in lgr.handlers: + h.flush() + result = self.lm.get_all_log_file_infos() + names = [e['file'] for e in result] + self.assertIn('svc.log', names) + + def test_includes_rotated_log_file(self): + # Create a fake rotated log file + rotated = os.path.join(str(self.lm.log_dir), 'svc.log.1') + with open(rotated, 'w') as f: + f.write('old log\n') + result = self.lm.get_all_log_file_infos() + names = [e['file'] for e in result] + self.assertIn('svc.log.1', names) + + +class TestCompressOldLogs(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + self.lm.stop() + shutil.rmtree(self.tmp) + + def test_compress_old_logs_compresses_rotated_file(self): + rotated = os.path.join(str(self.lm.log_dir), 'svc.log.1') + with open(rotated, 'w') as f: + f.write('old log content\n') + self.lm.compress_old_logs() + gz_path = rotated + '.gz' + self.assertTrue(os.path.exists(gz_path)) + self.assertFalse(os.path.exists(rotated)) + + def test_compress_old_logs_skips_already_compressed(self): + gz_path = os.path.join(str(self.lm.log_dir), 'svc.log.1.gz') + with gzip.open(gz_path, 'wb') as f: + f.write(b'already compressed') + self.lm.compress_old_logs() + # Original gz should still exist + self.assertTrue(os.path.exists(gz_path)) + + +class TestStop(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.lm = _lm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_stop_sets_running_false(self): + self.lm.stop() + self.assertFalse(self.lm.running) + + def test_stop_closes_all_handlers(self): + self.lm.add_service_logger('svc', {'level': 'INFO'}) + self.lm.stop() # must not raise + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_network_endpoints.py b/tests/test_network_endpoints.py index b5b1c14..0b3cdf4 100644 --- a/tests/test_network_endpoints.py +++ b/tests/test_network_endpoints.py @@ -1,15 +1,13 @@ #!/usr/bin/env python3 """ -Unit tests for network/DNS/DHCP Flask endpoints in api/app.py. +Unit tests for network/DNS Flask endpoints in api/app.py. Covers: GET /api/dns/records POST /api/dns/records DELETE /api/dns/records GET /api/dns/status - GET /api/dhcp/leases - POST /api/dhcp/reservations - DELETE /api/dhcp/reservations + GET /api/dns/overview GET /api/network/info POST /api/network/test """ @@ -150,149 +148,40 @@ class TestGetDnsStatus(unittest.TestCase): self.assertIn('error', json.loads(r.data)) -class TestGetDhcpLeases(unittest.TestCase): +class TestGetDnsOverview(unittest.TestCase): def setUp(self): app.config['TESTING'] = True self.client = app.test_client() + @patch('app.ddns_manager') + @patch('app.config_manager') @patch('app.network_manager') - def test_get_dhcp_leases_returns_200_with_list(self, mock_nm): - mock_nm.get_dhcp_leases.return_value = [ - {'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.101', 'hostname': 'laptop'}, - ] - r = self.client.get('/api/dhcp/leases') + def test_get_dns_overview_returns_200(self, mock_nm, mock_cm, mock_dm): + mock_nm.get_dns_overview.return_value = { + 'mode': 'pic_ngo', + 'provider': 'pic_ngo', + 'effective_domain': 'mycell.pic.ngo', + 'internal_domain': 'cell', + 'public_ip': '1.2.3.4', + 'public_records': [], + 'internal_records': [], + 'service_subdomains': [], + 'registration_status': {'registered': True}, + } + r = self.client.get('/api/dns/overview') self.assertEqual(r.status_code, 200) data = json.loads(r.data) - self.assertIsInstance(data, list) - self.assertEqual(len(data), 1) - self.assertEqual(data[0]['hostname'], 'laptop') + self.assertEqual(data['mode'], 'pic_ngo') + self.assertEqual(data['effective_domain'], 'mycell.pic.ngo') + mock_nm.get_dns_overview.assert_called_once_with(mock_cm, mock_dm) + @patch('app.ddns_manager') + @patch('app.config_manager') @patch('app.network_manager') - def test_get_dhcp_leases_returns_empty_list_when_no_leases(self, mock_nm): - mock_nm.get_dhcp_leases.return_value = [] - r = self.client.get('/api/dhcp/leases') - self.assertEqual(r.status_code, 200) - self.assertEqual(json.loads(r.data), []) - - @patch('app.network_manager') - def test_get_dhcp_leases_returns_500_on_exception(self, mock_nm): - mock_nm.get_dhcp_leases.side_effect = Exception('dnsmasq not running') - r = self.client.get('/api/dhcp/leases') - self.assertEqual(r.status_code, 500) - self.assertIn('error', json.loads(r.data)) - - -class TestAddDhcpReservation(unittest.TestCase): - - def setUp(self): - app.config['TESTING'] = True - self.client = app.test_client() - - @patch('app.network_manager') - def test_add_reservation_returns_200_on_valid_body(self, mock_nm): - mock_nm.add_dhcp_reservation.return_value = True - r = self.client.post( - '/api/dhcp/reservations', - data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50', 'hostname': 'printer'}), - content_type='application/json', - ) - self.assertEqual(r.status_code, 200) - data = json.loads(r.data) - self.assertIn('success', data) - - @patch('app.network_manager') - def test_add_reservation_returns_400_when_no_body(self, mock_nm): - r = self.client.post('/api/dhcp/reservations') - self.assertEqual(r.status_code, 400) - self.assertIn('error', json.loads(r.data)) - mock_nm.add_dhcp_reservation.assert_not_called() - - @patch('app.network_manager') - def test_add_reservation_returns_400_when_mac_missing(self, mock_nm): - r = self.client.post( - '/api/dhcp/reservations', - data=json.dumps({'ip': '192.168.1.50'}), - content_type='application/json', - ) - self.assertEqual(r.status_code, 400) - self.assertIn('error', json.loads(r.data)) - - @patch('app.network_manager') - def test_add_reservation_returns_400_when_ip_missing(self, mock_nm): - r = self.client.post( - '/api/dhcp/reservations', - data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), - content_type='application/json', - ) - self.assertEqual(r.status_code, 400) - self.assertIn('error', json.loads(r.data)) - - @patch('app.network_manager') - def test_add_reservation_uses_empty_hostname_when_omitted(self, mock_nm): - mock_nm.add_dhcp_reservation.return_value = True - self.client.post( - '/api/dhcp/reservations', - data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50'}), - content_type='application/json', - ) - mock_nm.add_dhcp_reservation.assert_called_once_with('aa:bb:cc:dd:ee:ff', '192.168.1.50', '') - - @patch('app.network_manager') - def test_add_reservation_returns_500_on_exception(self, mock_nm): - mock_nm.add_dhcp_reservation.side_effect = Exception('dnsmasq config error') - r = self.client.post( - '/api/dhcp/reservations', - data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50'}), - content_type='application/json', - ) - self.assertEqual(r.status_code, 500) - self.assertIn('error', json.loads(r.data)) - - -class TestDeleteDhcpReservation(unittest.TestCase): - - def setUp(self): - app.config['TESTING'] = True - self.client = app.test_client() - - @patch('app.network_manager') - def test_delete_reservation_returns_200_on_success(self, mock_nm): - mock_nm.remove_dhcp_reservation.return_value = True - r = self.client.delete( - '/api/dhcp/reservations', - data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), - content_type='application/json', - ) - self.assertEqual(r.status_code, 200) - data = json.loads(r.data) - self.assertIn('success', data) - - @patch('app.network_manager') - def test_delete_reservation_returns_400_when_mac_missing(self, mock_nm): - r = self.client.delete( - '/api/dhcp/reservations', - data=json.dumps({'hostname': 'printer'}), - content_type='application/json', - ) - self.assertEqual(r.status_code, 400) - self.assertIn('error', json.loads(r.data)) - mock_nm.remove_dhcp_reservation.assert_not_called() - - @patch('app.network_manager') - def test_delete_reservation_returns_400_when_no_body(self, mock_nm): - r = self.client.delete('/api/dhcp/reservations') - self.assertEqual(r.status_code, 400) - self.assertIn('error', json.loads(r.data)) - - @patch('app.network_manager') - def test_delete_reservation_returns_500_on_exception(self, mock_nm): - mock_nm.remove_dhcp_reservation.side_effect = Exception('reservation not found') - r = self.client.delete( - '/api/dhcp/reservations', - data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), - content_type='application/json', - ) + def test_get_dns_overview_returns_500_on_exception(self, mock_nm, mock_cm, mock_dm): + mock_nm.get_dns_overview.side_effect = Exception('boom') + r = self.client.get('/api/dns/overview') self.assertEqual(r.status_code, 500) self.assertIn('error', json.loads(r.data)) diff --git a/tests/test_network_manager.py b/tests/test_network_manager.py index 0cf8e94..4ca7940 100644 --- a/tests/test_network_manager.py +++ b/tests/test_network_manager.py @@ -46,8 +46,7 @@ class TestNetworkManager(unittest.TestCase): self.assertEqual(self.network_manager.data_dir, self.data_dir) self.assertEqual(self.network_manager.config_dir, self.config_dir) self.assertTrue(os.path.exists(self.network_manager.dns_zones_dir)) - self.assertTrue(os.path.exists(os.path.dirname(self.network_manager.dhcp_leases_file))) - + def test_generate_zone_content(self): """Test DNS zone content generation""" records = [ @@ -124,57 +123,6 @@ test2 1800 IN CNAME test1 self.assertEqual(records[1]['name'], 'test2') self.assertEqual(records[1]['type'], 'CNAME') - def test_get_dhcp_leases(self): - """Test getting DHCP leases""" - # Create a test leases file - leases_file = self.network_manager.dhcp_leases_file - content = """1234567890 aa:bb:cc:dd:ee:ff 192.168.1.100 testhost * -1234567891 11:22:33:44:55:66 192.168.1.101 anotherhost * -""" - - with open(leases_file, 'w') as f: - f.write(content) - - leases = self.network_manager.get_dhcp_leases() - - self.assertEqual(len(leases), 2) - self.assertEqual(leases[0]['mac'], 'aa:bb:cc:dd:ee:ff') - self.assertEqual(leases[0]['ip'], '192.168.1.100') - self.assertEqual(leases[0]['hostname'], 'testhost') - self.assertEqual(leases[1]['mac'], '11:22:33:44:55:66') - self.assertEqual(leases[1]['ip'], '192.168.1.101') - - def test_add_dhcp_reservation(self): - """Test adding DHCP reservation""" - success = self.network_manager.add_dhcp_reservation('aa:bb:cc:dd:ee:ff', '192.168.1.100', 'testhost') - self.assertTrue(success) - - # Check if reservation file was created - reservation_file = os.path.join(self.config_dir, 'dhcp', 'reservations.conf') - self.assertTrue(os.path.exists(reservation_file)) - - # Check content - with open(reservation_file, 'r') as f: - content = f.read() - self.assertIn('aa:bb:cc:dd:ee:ff', content) - self.assertIn('192.168.1.100', content) - self.assertIn('testhost', content) - - def test_remove_dhcp_reservation(self): - """Test removing DHCP reservation""" - # Add a reservation first - self.network_manager.add_dhcp_reservation('aa:bb:cc:dd:ee:ff', '192.168.1.100', 'testhost') - - # Remove it - success = self.network_manager.remove_dhcp_reservation('aa:bb:cc:dd:ee:ff') - self.assertTrue(success) - - # Check if reservation was removed - reservation_file = os.path.join(self.config_dir, 'dhcp', 'reservations.conf') - with open(reservation_file, 'r') as f: - content = f.read() - self.assertNotIn('aa:bb:cc:dd:ee:ff', content) - @patch('subprocess.run') def test_get_ntp_status(self, mock_run): """Test getting NTP status""" @@ -216,19 +164,6 @@ test2 1800 IN CNAME test1 self.assertFalse(result['success']) self.assertIn('NXDOMAIN', result['error']) - @patch('subprocess.run') - def test_test_dhcp_functionality(self, mock_run): - """Test DHCP functionality testing""" - # Mock DHCP service running - mock_run.return_value.stdout = 'cell-dhcp\n' - mock_run.return_value.returncode = 0 - - result = self.network_manager.test_dhcp_functionality() - - self.assertTrue(result['running']) - self.assertIn('leases_count', result) - self.assertIn('leases', result) - @patch('subprocess.run') def test_test_ntp_functionality(self, mock_run): """Test NTP functionality testing""" diff --git a/tests/test_network_manager_extra.py b/tests/test_network_manager_extra.py new file mode 100644 index 0000000..4c515f7 --- /dev/null +++ b/tests/test_network_manager_extra.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +""" +Additional tests for NetworkManager covering apply_domain, apply_cell_name, +get_status, get_dns_status, get_network_info, apply_config, and +input-validation paths not covered by the main test file. +""" + +import sys +import os +import json +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock, call + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from network_manager import NetworkManager + + +def _make_nm(tmp): + data_dir = os.path.join(tmp, 'data') + config_dir = os.path.join(tmp, 'config') + os.makedirs(data_dir, exist_ok=True) + os.makedirs(config_dir, exist_ok=True) + return NetworkManager(data_dir, config_dir), data_dir, config_dir + + +class TestApplyDomain(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('subprocess.run') + def test_apply_domain_no_config_files_returns_empty_warnings(self, _mock): + """apply_domain with no corefile or zone files is graceful.""" + result = self.nm.apply_domain('newcell', reload=False) + self.assertIsInstance(result, dict) + self.assertIn('restarted', result) + self.assertIn('warnings', result) + + @patch('subprocess.run') + def test_apply_domain_renames_and_rewrites_zone_file(self, _mock): + dns_data = os.path.join(self.data_dir, 'dns') + os.makedirs(dns_data, exist_ok=True) + zone_content = """$TTL 3600 +@ IN SOA oldcell. admin.oldcell. ( + 2026010101 ; Serial + 3600 1800 1209600 3600 ) +@ IN NS oldcell. +api 3600 IN A 10.0.0.1 +""" + with open(os.path.join(dns_data, 'oldcell.zone'), 'w') as f: + f.write(zone_content) + result = self.nm.apply_domain('newcell', reload=False) + new_zone = os.path.join(dns_data, 'newcell.zone') + self.assertTrue(os.path.exists(new_zone)) + with open(new_zone) as f: + written = f.read() + self.assertIn('newcell.', written) + + @patch('subprocess.run') + def test_apply_domain_reloads_dns_when_reload_true(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + result = self.nm.apply_domain('newcell', reload=True) + calls_str = str(mock_run.call_args_list) + self.assertIn('SIGUSR1', calls_str) + + @patch('subprocess.run') + def test_apply_domain_does_not_reload_when_reload_false(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + result = self.nm.apply_domain('newcell', reload=False) + calls_str = str(mock_run.call_args_list) + self.assertNotIn('SIGUSR1', calls_str) + + +class TestApplyCellName(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + self.dns_data = os.path.join(self.data_dir, 'dns') + os.makedirs(self.dns_data, exist_ok=True) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _write_zone(self, name, content): + with open(os.path.join(self.dns_data, f'{name}.zone'), 'w') as f: + f.write(content) + + @patch('subprocess.run') + def test_apply_cell_name_updates_hostname_record(self, _mock): + self._write_zone('cell', ( + 'oldname 3600 IN A 10.0.0.1\n' + 'api 3600 IN A 10.0.0.1\n' + )) + result = self.nm.apply_cell_name('oldname', 'newname') + with open(os.path.join(self.dns_data, 'cell.zone')) as f: + content = f.read() + self.assertIn('newname', content) + self.assertNotIn('oldname', content) + + @patch('subprocess.run') + def test_apply_cell_name_empty_new_name_does_nothing(self, mock_run): + result = self.nm.apply_cell_name('oldname', '') + mock_run.assert_not_called() + + @patch('subprocess.run') + def test_apply_cell_name_already_correct_does_not_write(self, mock_run): + self._write_zone('cell', ( + 'mycel 3600 IN A 10.0.0.1\n' + 'api 3600 IN A 10.0.0.1\n' + )) + result = self.nm.apply_cell_name('mycel', 'mycel') + # No reload should happen since name already matches + calls_str = str(mock_run.call_args_list) + self.assertNotIn('SIGUSR1', calls_str) + + @patch('subprocess.run') + def test_apply_cell_name_detects_hostname_when_old_name_absent(self, _mock): + """If old_name not in zone, detect hostname by non-service A record.""" + self._write_zone('cell', ( + 'detectedhost 3600 IN A 10.0.0.1\n' + 'calendar 3600 IN A 10.0.0.1\n' + )) + result = self.nm.apply_cell_name('', 'newname') + with open(os.path.join(self.dns_data, 'cell.zone')) as f: + content = f.read() + self.assertIn('newname', content) + self.assertNotIn('detectedhost', content) + + @patch('subprocess.run') + def test_apply_cell_name_skips_multi_label_zones(self, _mock): + """Multi-label zones (e.g. pic2.pic.ngo) must not be modified.""" + self._write_zone('cell', 'oldname 3600 IN A 10.0.0.1\n') + self._write_zone('pic2.pic.ngo', 'api 3600 IN A 10.0.0.1\n') + result = self.nm.apply_cell_name('oldname', 'newname') + with open(os.path.join(self.dns_data, 'pic2.pic.ngo.zone')) as f: + multi = f.read() + # multi-label zone should be unchanged + self.assertNotIn('newname', multi) + + +class TestApplyConfig(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('subprocess.run') + def test_apply_config_empty_dict_returns_no_warnings(self, _mock): + result = self.nm.apply_config({}) + self.assertEqual(result['warnings'], []) + + @patch('subprocess.run') + def test_apply_config_ntp_servers_updates_file(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + ntp_dir = os.path.join(self.config_dir, 'ntp') + os.makedirs(ntp_dir, exist_ok=True) + ntp_conf = os.path.join(ntp_dir, 'chrony.conf') + with open(ntp_conf, 'w') as f: + f.write('server 0.pool.ntp.org iburst\nmakestep 1.0 3\n') + result = self.nm.apply_config({'ntp_servers': ['1.1.1.1', '8.8.8.8']}) + with open(ntp_conf) as f: + content = f.read() + self.assertIn('server 1.1.1.1 iburst', content) + self.assertIn('server 8.8.8.8 iburst', content) + self.assertNotIn('0.pool.ntp.org', content) + + +class TestGetDnsStatus(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('subprocess.run') + def test_get_dns_status_returns_dict(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='cell-dns\n', stderr='') + result = self.nm.get_dns_status() + self.assertIn('running', result) + self.assertIn('records_count', result) + + @patch('subprocess.run') + def test_get_dns_status_running_true_when_container_up(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='cell-dns\n', stderr='') + result = self.nm.get_dns_status() + self.assertTrue(result['running']) + + @patch('subprocess.run') + def test_get_dns_status_counts_records(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='cell-dns\n', stderr='') + dns_data = os.path.join(self.data_dir, 'dns') + os.makedirs(dns_data, exist_ok=True) + with open(os.path.join(dns_data, 'cell.zone'), 'w') as f: + f.write('api 3600 IN A 10.0.0.1\nwebui 3600 IN A 10.0.0.1\n') + result = self.nm.get_dns_status() + self.assertEqual(result['records_count'], 2) + + @patch('subprocess.run') + def test_get_dns_status_not_running(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + result = self.nm.get_dns_status() + self.assertFalse(result['running']) + + +class TestGetNetworkInfo(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('subprocess.run') + def test_get_network_info_returns_dict(self, mock_run): + mock_run.return_value = MagicMock(returncode=1, stdout='', stderr='') + result = self.nm.get_network_info() + self.assertIsInstance(result, dict) + + @patch('subprocess.run') + def test_get_network_info_includes_dns_servers(self, mock_run): + mock_run.return_value = MagicMock(returncode=1, stdout='', stderr='') + result = self.nm.get_network_info() + self.assertIn('dns_servers', result) + + @patch('subprocess.run') + def test_get_network_info_parses_interfaces_json(self, mock_run): + iface_json = '[{"ifindex":1,"ifname":"lo","addr_info":[]}]' + mock_run.return_value = MagicMock(returncode=0, stdout=iface_json, stderr='') + result = self.nm.get_network_info() + self.assertIn('interfaces', result) + self.assertEqual(result['interfaces'][0]['ifname'], 'lo') + + +class TestGetStatus(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('subprocess.run') + def test_get_status_returns_dict_with_required_keys(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + result = self.nm.get_status() + self.assertIn('running', result) + self.assertIn('status', result) + + @patch('subprocess.run') + @patch.dict(os.environ, {'DOCKER_CONTAINER': 'false'}) + def test_get_status_non_docker_path(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout='', stderr='') + result = self.nm.get_status() + self.assertIn('dns_running', result) + self.assertIn('ntp_running', result) + + +class TestGetDnsRecords(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + @patch('subprocess.run') + def test_get_dns_records_returns_list(self, _mock): + result = self.nm.get_dns_records() + self.assertIsInstance(result, list) + + @patch('subprocess.run') + def test_get_dns_records_from_zone_file(self, _mock): + dns_data = os.path.join(self.data_dir, 'dns') + os.makedirs(dns_data, exist_ok=True) + with open(os.path.join(dns_data, 'cell.zone'), 'w') as f: + f.write('api 3600 IN A 10.0.0.1\nwebui 3600 IN A 10.0.0.1\n') + result = self.nm.get_dns_records() + self.assertEqual(len(result), 2) + names = [r['name'] for r in result] + self.assertIn('api', names) + self.assertIn('webui', names) + + +class TestInputValidation(unittest.TestCase): + """Test input validation guards in add_dns_record and update_dns_zone.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_add_dns_record_invalid_zone_returns_false(self): + result = self.nm.add_dns_record('../etc/passwd', 'test', 'A', '10.0.0.1') + self.assertFalse(result) + + def test_add_dns_record_invalid_name_returns_false(self): + result = self.nm.add_dns_record('cell', 'name with spaces', 'A', '10.0.0.1') + self.assertFalse(result) + + def test_add_dns_record_invalid_value_returns_false(self): + result = self.nm.add_dns_record('cell', 'test', 'A', '10.0.0.1; rm -rf /') + self.assertFalse(result) + + @patch('subprocess.run') + def test_update_dns_zone_invalid_record_name_returns_false(self, _mock): + records = [{'name': 'valid', 'type': 'A', 'value': '10.0.0.1'}, + {'name': 'bad\nname', 'type': 'A', 'value': '10.0.0.2'}] + result = self.nm.update_dns_zone('cell', records) + self.assertFalse(result) + + def test_update_dns_zone_invalid_zone_name_returns_false(self): + result = self.nm.update_dns_zone('', []) + self.assertFalse(result) + + +class TestGetWgServerIp(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.nm, self.data_dir, self.config_dir = _make_nm(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_returns_fallback_when_no_conf(self): + ip = self.nm._get_wg_server_ip() + self.assertEqual(ip, '10.0.0.1') + + def test_parses_wg_conf_address(self): + wg_dir = os.path.join(self.config_dir, 'wireguard', 'wg_confs') + os.makedirs(wg_dir, exist_ok=True) + with open(os.path.join(wg_dir, 'wg0.conf'), 'w') as f: + f.write('[Interface]\nAddress = 172.16.0.1/24\nPrivateKey = abc\n') + ip = self.nm._get_wg_server_ip() + self.assertEqual(ip, '172.16.0.1') + + +class TestGetDnsOverview(unittest.TestCase): + """get_dns_overview composes config_manager, registry, and ddns_manager + into a provider-aware structure without writing DNS.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.data_dir = os.path.join(self.tmp, 'data') + self.config_dir = os.path.join(self.tmp, 'config') + os.makedirs(os.path.join(self.data_dir, 'dns'), exist_ok=True) + self.registry = MagicMock() + self.registry.get_caddy_routes.return_value = [ + {'subdomain': 'mail', 'backend': 'cell-mail:80', + 'extra_subdomains': [], 'extra_backends': {}}, + ] + self.nm = NetworkManager(self.data_dir, self.config_dir, + service_registry=self.registry) + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _cm(self, identity, ddns=None, token=''): + cm = MagicMock() + cm.get_identity.return_value = identity + cm.configs = {'ddns': ddns or {}} + cm.get_ddns_token.return_value = token + mode = identity.get('domain_mode', 'lan') + if mode == 'lan': + cm.get_effective_domain.return_value = identity.get('domain', 'cell') + else: + cm.get_effective_domain.return_value = identity.get('domain_name', '') + cm.get_internal_domain.return_value = identity.get('domain', 'cell') + return cm + + def test_lan_mode_has_no_public_records(self): + cm = self._cm({'domain_mode': 'lan', 'domain': 'cell'}) + ov = self.nm.get_dns_overview(cm, ddns_manager=None) + self.assertEqual(ov['mode'], 'lan') + self.assertEqual(ov['public_records'], []) + self.assertEqual(ov['internal_domain'], 'cell') + self.assertIsNone(ov['public_ip']) + + def test_pic_ngo_mode_apex_and_wildcard(self): + cm = self._cm( + {'domain_mode': 'pic_ngo', 'domain': 'cell', 'domain_name': 'mycell.pic.ngo'}, + ddns={'provider': 'pic_ngo'}, token='tok') + ddns_mgr = MagicMock() + ddns_mgr.get_status.return_value = {'provider': 'pic_ngo'} + ov = self.nm.get_dns_overview(cm, ddns_manager=ddns_mgr, public_ip='1.2.3.4') + names = [r['name'] for r in ov['public_records']] + self.assertIn('mycell.pic.ngo', names) + self.assertIn('*.mycell.pic.ngo', names) + self.assertTrue(all(r['value'] == '1.2.3.4' for r in ov['public_records'])) + self.assertTrue(all(r['status'] == 'registered' for r in ov['public_records'])) + self.assertEqual(ov['service_subdomains'][0]['fqdn'], 'mail.mycell.pic.ngo') + + def test_cloudflare_mode_per_service_records(self): + cm = self._cm( + {'domain_mode': 'cloudflare', 'domain': 'cell', 'domain_name': 'cell.example.com'}, + ddns={'provider': 'cloudflare'}, token='cf') + ov = self.nm.get_dns_overview(cm, ddns_manager=None, public_ip='5.5.5.5') + names = [r['name'] for r in ov['public_records']] + self.assertIn('cell.example.com', names) + self.assertIn('mail.cell.example.com', names) + self.assertNotIn('*.cell.example.com', names) + + def test_custom_mode_per_service_records(self): + cm = self._cm( + {'domain_mode': 'custom', 'domain': 'cell', 'domain_name': 'cell.example.org'}, + ddns={'provider': 'custom'}) + ov = self.nm.get_dns_overview(cm, ddns_manager=None, public_ip='6.6.6.6') + names = [r['name'] for r in ov['public_records']] + self.assertIn('cell.example.org', names) + self.assertIn('mail.cell.example.org', names) + self.assertFalse(ov['registration_status']['registered']) + self.assertTrue(all(r['status'] == 'unregistered' for r in ov['public_records'])) + + def test_internal_records_come_from_zone_files(self): + with open(os.path.join(self.data_dir, 'dns', 'cell.zone'), 'w') as f: + f.write('api 3600 IN A 10.0.0.1\n') + cm = self._cm({'domain_mode': 'lan', 'domain': 'cell'}) + ov = self.nm.get_dns_overview(cm, ddns_manager=None) + self.assertEqual(len(ov['internal_records']), 1) + self.assertEqual(ov['internal_records'][0]['name'], 'api') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_peer_provisioning.py b/tests/test_peer_provisioning.py index 03ce6ac..99d87cb 100644 --- a/tests/test_peer_provisioning.py +++ b/tests/test_peer_provisioning.py @@ -171,18 +171,85 @@ def test_create_peer_returns_201(admin_client): def test_create_peer_provisions_all_services( - admin_client, auth_mgr, - mock_email_mgr, mock_calendar_mgr, mock_file_mgr): - """All four service create methods must be called exactly once.""" - _post_peer(admin_client) - # auth provisioning — check user was created in the real auth_mgr - # (we use the real auth_mgr so we can inspect the result directly) - alice = auth_mgr.get_user('alice') - assert alice is not None, 'auth_manager.create_user was not called for alice' + auth_mgr, mock_email_mgr, mock_calendar_mgr, + mock_file_mgr, mock_wg_mgr, mock_peer_registry): + """With email/calendar/files all installed, all four service create methods + must be called exactly once.""" + patches = _make_admin_client_with_installed( + auth_mgr, mock_email_mgr, mock_calendar_mgr, + mock_file_mgr, mock_wg_mgr, mock_peer_registry, + installed_services={'email': {}, 'calendar': {}, 'files': {}}, + ) + started = [p.start() for p in patches] + try: + with app.test_client() as client: + _login(client) + r = _post_peer(client) + assert r.status_code == 201, f'{r.status_code}: {r.data}' - mock_email_mgr.create_email_user.assert_called_once() - mock_calendar_mgr.create_calendar_user.assert_called_once() - mock_file_mgr.create_user.assert_called_once() + # auth provisioning — check user was created in the real auth_mgr + # (we use the real auth_mgr so we can inspect the result directly) + alice = auth_mgr.get_user('alice') + assert alice is not None, 'auth_manager.create_user was not called for alice' + + mock_email_mgr.create_email_user.assert_called_once() + mock_calendar_mgr.create_calendar_user.assert_called_once() + mock_file_mgr.create_user.assert_called_once() + finally: + for p in patches: + p.stop() + + +def test_create_peer_skips_builtin_provisioning_when_not_installed( + auth_mgr, mock_email_mgr, mock_calendar_mgr, + mock_file_mgr, mock_wg_mgr, mock_peer_registry): + """With no store services installed, email/calendar/files account creation + must NOT be attempted — those services do not exist on this cell.""" + patches = _make_admin_client_with_installed( + auth_mgr, mock_email_mgr, mock_calendar_mgr, + mock_file_mgr, mock_wg_mgr, mock_peer_registry, + installed_services={}, + ) + started = [p.start() for p in patches] + try: + with app.test_client() as client: + _login(client) + r = _post_peer(client) + assert r.status_code == 201, f'{r.status_code}: {r.data}' + + # Auth account is always created + assert auth_mgr.get_user('alice') is not None + + mock_email_mgr.create_email_user.assert_not_called() + mock_calendar_mgr.create_calendar_user.assert_not_called() + mock_file_mgr.create_user.assert_not_called() + finally: + for p in patches: + p.stop() + + +def test_create_peer_provisions_only_installed_subset( + auth_mgr, mock_email_mgr, mock_calendar_mgr, + mock_file_mgr, mock_wg_mgr, mock_peer_registry): + """With only calendar installed, only the calendar account is provisioned.""" + patches = _make_admin_client_with_installed( + auth_mgr, mock_email_mgr, mock_calendar_mgr, + mock_file_mgr, mock_wg_mgr, mock_peer_registry, + installed_services={'calendar': {}}, + ) + started = [p.start() for p in patches] + try: + with app.test_client() as client: + _login(client) + r = _post_peer(client) + assert r.status_code == 201, f'{r.status_code}: {r.data}' + + mock_calendar_mgr.create_calendar_user.assert_called_once() + mock_email_mgr.create_email_user.assert_not_called() + mock_file_mgr.create_user.assert_not_called() + finally: + for p in patches: + p.stop() def test_create_peer_response_has_ip(admin_client): @@ -231,8 +298,13 @@ def test_create_peer_email_failure_is_nonfatal( app.config['TESTING'] = True app.config['SECRET_KEY'] = 'test-secret' + # email must be installed for its provisioning step to run at all + mock_cfg = MagicMock() + mock_cfg.get_installed_services.return_value = {'email': {}} + patches = [ patch('app.auth_manager', auth_mgr), + patch('app.config_manager', mock_cfg), patch('app.email_manager', mock_email_mgr), patch('app.calendar_manager', mock_calendar_mgr), patch('app.file_manager', mock_file_mgr), diff --git a/tests/test_peer_registry.py b/tests/test_peer_registry.py index fa9771a..09d87ec 100644 --- a/tests/test_peer_registry.py +++ b/tests/test_peer_registry.py @@ -107,5 +107,123 @@ class TestPeerRegistry(unittest.TestCase): with self.assertRaises(ValueError): self.registry.set_route_via('nobody', 'exit-cell') + def test_set_peer_exit_via_valid(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + result = self.registry.set_peer_exit_via('alice', 'wireguard_ext') + self.assertTrue(result) + peer = self.registry.get_peer('alice') + self.assertEqual(peer['exit_via'], 'wireguard_ext') + + def test_set_peer_exit_via_all_valid_types(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + for exit_type in ('default', 'wireguard_ext', 'openvpn', 'tor'): + result = self.registry.set_peer_exit_via('alice', exit_type) + self.assertTrue(result) + peer = self.registry.get_peer('alice') + self.assertEqual(peer['exit_via'], exit_type) + + def test_set_peer_exit_via_invalid_type_returns_false(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + result = self.registry.set_peer_exit_via('alice', 'invalid_exit') + self.assertFalse(result) + + def test_set_peer_exit_via_nonexistent_peer_returns_false(self): + result = self.registry.set_peer_exit_via('nobody', 'default') + self.assertFalse(result) + + def test_set_peer_exit_via_persists(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + self.registry.set_peer_exit_via('alice', 'tor') + reloaded = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir) + self.assertEqual(reloaded.get_peer('alice')['exit_via'], 'tor') + + def test_update_peer_updates_arbitrary_fields(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + result = self.registry.update_peer('alice', {'custom_field': 'hello', 'ip': '10.0.0.99'}) + self.assertTrue(result) + peer = self.registry.get_peer('alice') + self.assertEqual(peer['custom_field'], 'hello') + self.assertEqual(peer['ip'], '10.0.0.99') + + def test_update_peer_nonexistent_returns_false(self): + result = self.registry.update_peer('nobody', {'ip': '10.0.0.99'}) + self.assertFalse(result) + + def test_clear_reinstall_flag(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5', + 'config_needs_reinstall': True}) + result = self.registry.clear_reinstall_flag('alice') + self.assertTrue(result) + peer = self.registry.get_peer('alice') + self.assertFalse(peer['config_needs_reinstall']) + + def test_get_peer_stats_empty(self): + stats = self.registry.get_peer_stats() + self.assertEqual(stats['total_peers'], 0) + self.assertEqual(stats['active_peers'], 0) + self.assertEqual(stats['inactive_peers'], 0) + self.assertEqual(stats['ip_ranges'], {}) + + def test_get_peer_stats_with_peers(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.2', 'active': True}) + self.registry.add_peer({'peer': 'bob', 'ip': '10.0.0.3', 'active': False}) + stats = self.registry.get_peer_stats() + self.assertEqual(stats['total_peers'], 2) + self.assertEqual(stats['active_peers'], 1) + self.assertEqual(stats['inactive_peers'], 1) + self.assertIn('10.0.0.0/24', stats['ip_ranges']) + self.assertEqual(stats['ip_ranges']['10.0.0.0/24'], 2) + + def test_get_status_returns_correct_counts(self): + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.2', 'active': True}) + self.registry.add_peer({'peer': 'bob', 'ip': '10.0.0.3', 'active': False}) + status = self.registry.get_status() + self.assertEqual(status['peers_count'], 2) + self.assertEqual(status['active_peers'], 1) + self.assertEqual(status['inactive_peers'], 1) + self.assertTrue(status['running']) + + def test_test_connectivity_returns_dict(self): + result = self.registry.test_connectivity() + self.assertIn('filesystem_access', result) + self.assertIn('data_integrity', result) + self.assertIn('peer_operations', result) + self.assertIn('success', result) + + def test_test_connectivity_success(self): + result = self.registry.test_connectivity() + # All subtests should succeed since data dir exists + self.assertTrue(result['filesystem_access']['success']) + self.assertTrue(result['data_integrity']['success']) + + def test_list_peers_returns_copy(self): + """Modifying the returned list shouldn't affect internal state.""" + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + peers = self.registry.list_peers() + peers.clear() + self.assertEqual(len(self.registry.list_peers()), 1) + + def test_exit_via_migration_adds_field(self): + """Existing peers without exit_via get it as 'default' on load.""" + import json as _json + peers_file = os.path.join(self.test_dir, 'peers.json') + raw = [{'peer': 'alice', 'ip': '10.0.0.5', 'public_key': 'key=', + 'active': True, 'route_via': None, 'created_at': '2026-01-01T00:00:00'}] + with open(peers_file, 'w') as f: + _json.dump(raw, f) + reg = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir) + peer = reg.get_peer('alice') + self.assertIn('exit_via', peer) + self.assertEqual(peer['exit_via'], 'default') + + def test_save_peers_uses_restrictive_permissions(self): + """peers.json should be created with mode 0o600.""" + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + import stat + mode = os.stat(self.registry.peers_file).st_mode + perms = stat.S_IMODE(mode) + self.assertEqual(perms, 0o600) + + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() \ No newline at end of file diff --git a/tests/test_routes_containers.py b/tests/test_routes_containers.py new file mode 100644 index 0000000..a1a36c3 --- /dev/null +++ b/tests/test_routes_containers.py @@ -0,0 +1,377 @@ +""" +Tests for routes/containers.py — container, image, and volume management endpoints. +All endpoints require is_local_request() to return True; non-local gets 403. +""" +import sys +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + +import app as app_module +from app import app + + +@pytest.fixture +def local_client(): + """Flask test client that appears as local (loopback) request.""" + app.config['TESTING'] = True + with app.test_client() as c: + # is_local_request is imported inside each route handler from app + with patch.object(app_module, 'is_local_request', return_value=True): + yield c + + +@pytest.fixture +def remote_client(): + """Flask test client that appears as a non-local (remote) request.""" + app.config['TESTING'] = True + with app.test_client() as c: + with patch.object(app_module, 'is_local_request', return_value=False): + yield c + + +# --------------------------------------------------------------------------- +# GET /api/containers — requires local +# --------------------------------------------------------------------------- + +class TestListContainers: + def test_returns_200_from_local(self, local_client): + mock_cm = MagicMock() + mock_cm.list_containers.return_value = [{'name': 'cell-dns', 'status': 'running'}] + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/containers') + assert resp.status_code == 200 + + def test_returns_containers_list(self, local_client): + mock_cm = MagicMock() + mock_cm.list_containers.return_value = [{'name': 'cell-dns'}] + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/containers') + data = json.loads(resp.data) + assert isinstance(data, list) + assert data[0]['name'] == 'cell-dns' + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.get('/api/containers') + assert resp.status_code == 403 + + def test_500_on_exception(self, local_client): + mock_cm = MagicMock() + mock_cm.list_containers.side_effect = Exception('docker error') + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/containers') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/containers//start +# --------------------------------------------------------------------------- + +class TestStartContainer: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.start_container.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers/cell-dns/start') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.post('/api/containers/cell-dns/start') + assert resp.status_code == 403 + + def test_response_shape(self, local_client): + mock_cm = MagicMock() + mock_cm.start_container.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers/cell-dns/start') + data = json.loads(resp.data) + assert 'started' in data + + +# --------------------------------------------------------------------------- +# POST /api/containers//stop +# --------------------------------------------------------------------------- + +class TestStopContainer: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.stop_container.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers/cell-dns/stop') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.post('/api/containers/cell-dns/stop') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# POST /api/containers//restart +# --------------------------------------------------------------------------- + +class TestRestartContainer: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.restart_container.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers/cell-dns/restart') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.post('/api/containers/cell-dns/restart') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# GET /api/containers//logs +# --------------------------------------------------------------------------- + +class TestGetContainerLogs: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.get_container_logs.return_value = ['line1', 'line2'] + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/containers/cell-dns/logs') + assert resp.status_code == 200 + + def test_returns_logs_in_response(self, local_client): + mock_cm = MagicMock() + mock_cm.get_container_logs.return_value = ['line1'] + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/containers/cell-dns/logs?tail=50') + data = json.loads(resp.data) + assert 'logs' in data + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.get('/api/containers/cell-dns/logs') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# GET /api/containers//stats +# --------------------------------------------------------------------------- + +class TestGetContainerStats: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.get_container_stats.return_value = {'cpu': '5%', 'memory': '100MB'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/containers/cell-dns/stats') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.get('/api/containers/cell-dns/stats') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# POST /api/containers (create) +# --------------------------------------------------------------------------- + +class TestCreateContainer: + def test_returns_400_when_image_missing(self, local_client): + resp = local_client.post('/api/containers', + data=json.dumps({'name': 'test'}), + content_type='application/json') + assert resp.status_code == 400 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.post('/api/containers', + data=json.dumps({'image': 'nginx'}), + content_type='application/json') + assert resp.status_code == 403 + + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.create_container.return_value = {'id': 'abc123', 'name': 'mycontainer'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers', + data=json.dumps({'image': 'nginx', 'name': 'mycontainer'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_returns_500_when_result_has_error(self, local_client): + mock_cm = MagicMock() + mock_cm.create_container.return_value = {'error': 'image not found'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers', + data=json.dumps({'image': 'badimage'}), + content_type='application/json') + assert resp.status_code == 500 + + def test_volume_outside_allowed_path_returns_403(self, local_client): + """Volume mounts outside the allowed directories are blocked.""" + mock_cm = MagicMock() + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers', + data=json.dumps({ + 'image': 'nginx', + 'volumes': {'/etc/passwd': '/mnt/passwd'} + }), + content_type='application/json') + assert resp.status_code == 403 + + def test_volume_in_allowed_path_passes(self, local_client): + """Volume mounts under /tmp/ are permitted.""" + mock_cm = MagicMock() + mock_cm.create_container.return_value = {'id': 'abc'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/containers', + data=json.dumps({ + 'image': 'nginx', + 'volumes': {'/tmp/test': '/mnt/test'} + }), + content_type='application/json') + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# DELETE /api/containers/ +# --------------------------------------------------------------------------- + +class TestRemoveContainer: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.remove_container.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.delete('/api/containers/mycontainer') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.delete('/api/containers/mycontainer') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# GET /api/images +# --------------------------------------------------------------------------- + +class TestListImages: + def test_returns_200(self, local_client): + mock_cm = MagicMock() + mock_cm.list_images.return_value = [{'tag': 'nginx:latest'}] + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/images') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.get('/api/images') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# POST /api/images/pull +# --------------------------------------------------------------------------- + +class TestPullImage: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.pull_image.return_value = {'status': 'pulled'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/images/pull', + data=json.dumps({'image': 'nginx:latest'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_returns_400_when_image_missing(self, local_client): + resp = local_client.post('/api/images/pull', + data=json.dumps({}), + content_type='application/json') + assert resp.status_code == 400 + + def test_returns_500_when_pull_fails(self, local_client): + mock_cm = MagicMock() + mock_cm.pull_image.return_value = {'error': 'not found'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/images/pull', + data=json.dumps({'image': 'badimage'}), + content_type='application/json') + assert resp.status_code == 500 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.post('/api/images/pull', + data=json.dumps({'image': 'nginx'}), + content_type='application/json') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# DELETE /api/images/ +# --------------------------------------------------------------------------- + +class TestRemoveImage: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.remove_image.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.delete('/api/images/nginx:latest') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.delete('/api/images/nginx:latest') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# GET /api/volumes +# --------------------------------------------------------------------------- + +class TestListVolumes: + def test_returns_200(self, local_client): + mock_cm = MagicMock() + mock_cm.list_volumes.return_value = [] + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.get('/api/volumes') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.get('/api/volumes') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# POST /api/volumes (create) +# --------------------------------------------------------------------------- + +class TestCreateVolume: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.create_volume.return_value = {'name': 'myvolume'} + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.post('/api/volumes', + data=json.dumps({'name': 'myvolume'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_returns_400_when_name_missing(self, local_client): + resp = local_client.post('/api/volumes', + data=json.dumps({}), + content_type='application/json') + assert resp.status_code == 400 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.post('/api/volumes', + data=json.dumps({'name': 'v'}), + content_type='application/json') + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# DELETE /api/volumes/ +# --------------------------------------------------------------------------- + +class TestRemoveVolume: + def test_returns_200_on_success(self, local_client): + mock_cm = MagicMock() + mock_cm.remove_volume.return_value = True + with patch.object(app_module, 'container_manager', mock_cm): + resp = local_client.delete('/api/volumes/myvolume') + assert resp.status_code == 200 + + def test_returns_403_from_remote(self, remote_client): + resp = remote_client.delete('/api/volumes/myvolume') + assert resp.status_code == 403 diff --git a/tests/test_routes_service_store.py b/tests/test_routes_service_store.py new file mode 100644 index 0000000..ea63052 --- /dev/null +++ b/tests/test_routes_service_store.py @@ -0,0 +1,268 @@ +""" +Tests for routes/service_store.py: +- GET /api/store/services +- GET /api/store/services//manifest +- POST /api/store/services//install +- DELETE /api/store/services/ +- GET /api/store/installed +- POST /api/store/refresh +""" +import sys +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + +import app as app_module +from app import app + + +@pytest.fixture +def client(): + app.config['TESTING'] = True + with app.test_client() as c: + yield c + + +# --------------------------------------------------------------------------- +# GET /api/store/services +# --------------------------------------------------------------------------- + +class TestListStoreServices: + def test_returns_200(self, client): + mock_ssm = MagicMock() + mock_ssm.list_services.return_value = {'available': [], 'installed': []} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.get('/api/store/services') + assert resp.status_code == 200 + + def test_returns_service_index(self, client): + mock_ssm = MagicMock() + mock_ssm.list_services.return_value = { + 'available': [{'id': 'nextcloud', 'name': 'Nextcloud'}], + 'installed': [] + } + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.get('/api/store/services') + data = json.loads(resp.data) + assert 'available' in data + assert len(data['available']) == 1 + + def test_500_on_exception(self, client): + mock_ssm = MagicMock() + mock_ssm.list_services.side_effect = Exception('network error') + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.get('/api/store/services') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/store/services//manifest +# --------------------------------------------------------------------------- + +class TestGetManifest: + def test_returns_200_on_success(self, client): + import requests as _requests + mock_resp = MagicMock() + mock_resp.json.return_value = {'id': 'nextcloud', 'version': '1.0'} + mock_resp.raise_for_status.return_value = None + with patch('routes.service_store._requests') as mock_req: + mock_req.get.return_value = mock_resp + mock_req.HTTPError = _requests.HTTPError + resp = client.get('/api/store/services/nextcloud/manifest') + assert resp.status_code == 200 + + def test_returns_manifest_data(self, client): + import requests as _requests + mock_resp = MagicMock() + mock_resp.json.return_value = {'id': 'nextcloud', 'version': '1.0', 'name': 'Nextcloud'} + mock_resp.raise_for_status.return_value = None + with patch('routes.service_store._requests') as mock_req: + mock_req.get.return_value = mock_resp + mock_req.HTTPError = _requests.HTTPError + resp = client.get('/api/store/services/nextcloud/manifest') + data = json.loads(resp.data) + assert data['id'] == 'nextcloud' + + def test_returns_404_on_http_error(self, client): + import requests as _requests + with patch('routes.service_store._requests') as mock_req: + mock_req.HTTPError = _requests.HTTPError + mock_req.get.return_value = MagicMock( + raise_for_status=MagicMock( + side_effect=_requests.HTTPError('404 Not Found'))) + resp = client.get('/api/store/services/unknown/manifest') + assert resp.status_code == 404 + + def test_500_on_network_error(self, client): + import requests as _requests + with patch('routes.service_store._requests') as mock_req: + mock_req.HTTPError = _requests.HTTPError + mock_req.get.side_effect = Exception('network timeout') + resp = client.get('/api/store/services/nextcloud/manifest') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/store/services//install +# --------------------------------------------------------------------------- + +class TestInstallService: + def test_returns_200_on_success(self, client): + mock_ssm = MagicMock() + mock_ssm.install.return_value = {'ok': True, 'message': 'Installed'} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/services/nextcloud/install', + content_type='application/json') + assert resp.status_code == 200 + + def test_returns_install_result(self, client): + mock_ssm = MagicMock() + mock_ssm.install.return_value = {'ok': True, 'message': 'Installed'} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/services/nextcloud/install', + content_type='application/json') + data = json.loads(resp.data) + assert data['ok'] is True + + def test_returns_400_on_failure(self, client): + mock_ssm = MagicMock() + mock_ssm.install.return_value = {'ok': False, 'error': 'Manifest not found'} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/services/nextcloud/install', + content_type='application/json') + assert resp.status_code == 400 + + def test_normalizes_stderr_to_error_key(self, client): + """When ok=False but only stderr is set, it becomes the error key.""" + mock_ssm = MagicMock() + mock_ssm.install.return_value = {'ok': False, 'stderr': 'docker pull failed'} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/services/nextcloud/install', + content_type='application/json') + data = json.loads(resp.data) + assert data.get('error') == 'docker pull failed' + assert resp.status_code == 400 + + def test_500_on_exception(self, client): + mock_ssm = MagicMock() + mock_ssm.install.side_effect = Exception('unexpected error') + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/services/nextcloud/install', + content_type='application/json') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# DELETE /api/store/services/ +# --------------------------------------------------------------------------- + +class TestRemoveService: + def test_returns_200_on_success(self, client): + mock_ssm = MagicMock() + mock_ssm.remove.return_value = {'ok': True, 'message': 'Removed'} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.delete('/api/store/services/nextcloud') + assert resp.status_code == 200 + + def test_returns_404_when_not_installed(self, client): + mock_ssm = MagicMock() + mock_ssm.remove.return_value = {'ok': False, 'error': 'not installed'} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.delete('/api/store/services/nextcloud') + assert resp.status_code == 404 + + def test_passes_purge_flag(self, client): + mock_ssm = MagicMock() + mock_ssm.remove.return_value = {'ok': True} + with patch.object(app_module, 'service_store_manager', mock_ssm): + client.delete('/api/store/services/nextcloud?purge=true') + mock_ssm.remove.assert_called_once_with('nextcloud', purge_data=True) + + def test_purge_false_by_default(self, client): + mock_ssm = MagicMock() + mock_ssm.remove.return_value = {'ok': True} + with patch.object(app_module, 'service_store_manager', mock_ssm): + client.delete('/api/store/services/nextcloud') + mock_ssm.remove.assert_called_once_with('nextcloud', purge_data=False) + + def test_500_on_exception(self, client): + mock_ssm = MagicMock() + mock_ssm.remove.side_effect = Exception('docker error') + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.delete('/api/store/services/nextcloud') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/store/installed +# --------------------------------------------------------------------------- + +class TestGetInstalled: + def test_returns_200(self, client): + mock_cm = MagicMock() + mock_cm.get_installed_services.return_value = ['nextcloud'] + with patch.object(app_module, 'config_manager', mock_cm): + resp = client.get('/api/store/installed') + assert resp.status_code == 200 + + def test_returns_installed_list(self, client): + mock_cm = MagicMock() + mock_cm.get_installed_services.return_value = ['nextcloud', 'gitea'] + with patch.object(app_module, 'config_manager', mock_cm): + resp = client.get('/api/store/installed') + data = json.loads(resp.data) + assert 'installed' in data + assert 'nextcloud' in data['installed'] + assert 'gitea' in data['installed'] + + def test_returns_empty_when_nothing_installed(self, client): + mock_cm = MagicMock() + mock_cm.get_installed_services.return_value = [] + with patch.object(app_module, 'config_manager', mock_cm): + resp = client.get('/api/store/installed') + data = json.loads(resp.data) + assert data['installed'] == [] + + def test_500_on_exception(self, client): + mock_cm = MagicMock() + mock_cm.get_installed_services.side_effect = Exception('config error') + with patch.object(app_module, 'config_manager', mock_cm): + resp = client.get('/api/store/installed') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/store/refresh +# --------------------------------------------------------------------------- + +class TestRefreshIndex: + def test_returns_200(self, client): + mock_ssm = MagicMock() + mock_ssm._index_cache = {'data': 'old'} + mock_ssm._index_cache_time = 12345 + mock_ssm.list_services.return_value = {'available': [], 'installed': []} + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/refresh', + content_type='application/json') + assert resp.status_code == 200 + + def test_clears_cache(self, client): + mock_ssm = MagicMock() + mock_ssm._index_cache = {'data': 'old'} + mock_ssm._index_cache_time = 12345 + mock_ssm.list_services.return_value = {} + with patch.object(app_module, 'service_store_manager', mock_ssm): + client.post('/api/store/refresh', content_type='application/json') + assert mock_ssm._index_cache is None + assert mock_ssm._index_cache_time == 0 + + def test_500_on_exception(self, client): + mock_ssm = MagicMock() + mock_ssm.list_services.side_effect = Exception('cache error') + with patch.object(app_module, 'service_store_manager', mock_ssm): + resp = client.post('/api/store/refresh', content_type='application/json') + assert resp.status_code == 500 diff --git a/tests/test_routes_services_catalog.py b/tests/test_routes_services_catalog.py new file mode 100644 index 0000000..91bb688 --- /dev/null +++ b/tests/test_routes_services_catalog.py @@ -0,0 +1,1043 @@ +""" +Tests for routes/services.py: +- /api/services/catalog (list, get, status, restart, reconfigure) +- /api/services/catalog//accounts (list, provision, deprovision, credentials) +- /api/services/bus/* (status, events, start, stop, restart) +- /api/logs/* endpoints +""" +import sys +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + +import app as app_module +from app import app + + +@pytest.fixture +def client(): + app.config['TESTING'] = True + with app.test_client() as c: + yield c + + +def _make_registry(services): + reg = MagicMock() + reg.list_all = MagicMock(return_value=services) + reg.list_active = MagicMock(return_value=services) + reg.get = MagicMock(side_effect=lambda sid: next( + (s for s in services if s['id'] == sid), None)) + return reg + + +def _make_service(sid, kind='builtin'): + return {'id': sid, 'name': sid.title(), 'kind': kind, + 'subdomain': sid, 'capabilities': {}} + + +# --------------------------------------------------------------------------- +# GET /api/services/catalog +# --------------------------------------------------------------------------- + +class TestGetServicesCatalog: + def test_returns_200(self, client): + reg = _make_registry([]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog') + assert resp.status_code == 200 + + def test_returns_services_list(self, client): + reg = _make_registry([_make_service('email')]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog') + data = json.loads(resp.data) + assert 'services' in data + assert len(data['services']) == 1 + + def test_returns_empty_when_no_services(self, client): + reg = _make_registry([]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog') + data = json.loads(resp.data) + assert data['services'] == [] + + def test_500_on_exception(self, client): + reg = MagicMock() + reg.list_all.side_effect = Exception('registry error') + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/catalog/ +# --------------------------------------------------------------------------- + +class TestGetServiceCatalogEntry: + def test_returns_200_for_known_service(self, client): + reg = _make_registry([_make_service('email')]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog/email') + assert resp.status_code == 200 + + def test_returns_service_data(self, client): + reg = _make_registry([_make_service('email')]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog/email') + data = json.loads(resp.data) + assert data['id'] == 'email' + + def test_returns_404_for_unknown_service(self, client): + reg = _make_registry([]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog/nonexistent') + assert resp.status_code == 404 + + def test_404_includes_error_message(self, client): + reg = _make_registry([]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog/nonexistent') + data = json.loads(resp.data) + assert 'error' in data + + def test_500_on_exception(self, client): + reg = MagicMock() + reg.get.side_effect = Exception('registry error') + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/catalog/email') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/catalog//status +# --------------------------------------------------------------------------- + +class TestGetServiceContainerStatus: + def test_returns_200(self, client): + svc = _make_service('email') + reg = _make_registry([svc]) + composer = MagicMock() + composer.status_service.return_value = {'running': True, 'containers': []} + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.get('/api/services/catalog/email/status') + assert resp.status_code == 200 + + def test_returns_404_for_unknown_service(self, client): + reg = _make_registry([]) + composer = MagicMock() + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.get('/api/services/catalog/nonexistent/status') + assert resp.status_code == 404 + + def test_400_on_value_error(self, client): + svc = _make_service('email') + reg = _make_registry([svc]) + composer = MagicMock() + composer.status_service.side_effect = ValueError('bad service') + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.get('/api/services/catalog/email/status') + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# POST /api/services/catalog//restart +# --------------------------------------------------------------------------- + +class TestRestartServiceContainers: + def test_returns_200_on_success(self, client): + svc = _make_service('email') + reg = _make_registry([svc]) + composer = MagicMock() + composer.restart_service.return_value = {'ok': True, 'stdout': ''} + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/email/restart') + assert resp.status_code == 200 + + def test_returns_404_for_unknown_service(self, client): + reg = _make_registry([]) + composer = MagicMock() + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/nonexistent/restart') + assert resp.status_code == 404 + + def test_returns_500_when_restart_fails(self, client): + svc = _make_service('email') + reg = _make_registry([svc]) + composer = MagicMock() + composer.restart_service.return_value = {'ok': False, 'stderr': 'container error'} + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/email/restart') + assert resp.status_code == 500 + + def test_response_includes_message_on_success(self, client): + svc = _make_service('email') + reg = _make_registry([svc]) + composer = MagicMock() + composer.restart_service.return_value = {'ok': True} + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/email/restart') + data = json.loads(resp.data) + assert 'message' in data + + +# --------------------------------------------------------------------------- +# POST /api/services/catalog//reconfigure +# --------------------------------------------------------------------------- + +class TestReconfigureService: + def test_404_for_unknown_service(self, client): + reg = _make_registry([]) + composer = MagicMock() + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/nonexistent/reconfigure') + assert resp.status_code == 404 + + def test_400_for_builtin_service(self, client): + svc = _make_service('wireguard', kind='builtin') + reg = _make_registry([svc]) + composer = MagicMock() + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/wireguard/reconfigure') + assert resp.status_code == 400 + data = json.loads(resp.data) + assert 'Builtins' in data['error'] + + def test_400_when_no_compose_file(self, client): + svc = _make_service('myapp', kind='store') + reg = _make_registry([svc]) + composer = MagicMock() + composer.has_compose_file.return_value = False + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/myapp/reconfigure') + assert resp.status_code == 400 + + def test_200_on_success(self, client): + svc = _make_service('myapp', kind='store') + reg = _make_registry([svc]) + composer = MagicMock() + composer.has_compose_file.return_value = True + composer.up.return_value = {'ok': True, 'stdout': ''} + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/myapp/reconfigure') + assert resp.status_code == 200 + + def test_500_when_up_fails(self, client): + svc = _make_service('myapp', kind='store') + reg = _make_registry([svc]) + composer = MagicMock() + composer.has_compose_file.return_value = True + composer.up.return_value = {'ok': False, 'stderr': 'error'} + with patch.object(app_module, 'service_registry', reg), \ + patch.object(app_module, 'service_composer', composer): + resp = client.post('/api/services/catalog/myapp/reconfigure') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/catalog//accounts +# --------------------------------------------------------------------------- + +class TestListServiceAccounts: + def test_returns_200(self, client): + am = MagicMock() + am.list_accounts.return_value = ['alice', 'bob'] + with patch.object(app_module, 'account_manager', am): + resp = client.get('/api/services/catalog/email/accounts') + assert resp.status_code == 200 + + def test_returns_accounts_list(self, client): + am = MagicMock() + am.list_accounts.return_value = ['alice'] + with patch.object(app_module, 'account_manager', am): + resp = client.get('/api/services/catalog/email/accounts') + data = json.loads(resp.data) + assert data['accounts'] == ['alice'] + assert data['service_id'] == 'email' + + def test_500_on_exception(self, client): + am = MagicMock() + am.list_accounts.side_effect = Exception('fail') + with patch.object(app_module, 'account_manager', am): + resp = client.get('/api/services/catalog/email/accounts') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/services/catalog//accounts +# --------------------------------------------------------------------------- + +class TestProvisionServiceAccount: + def test_returns_201_on_success(self, client): + am = MagicMock() + am.provision.return_value = None + with patch.object(app_module, 'account_manager', am): + resp = client.post('/api/services/catalog/email/accounts', + data=json.dumps({'username': 'alice'}), + content_type='application/json') + assert resp.status_code == 201 + + def test_400_when_username_missing(self, client): + am = MagicMock() + with patch.object(app_module, 'account_manager', am): + resp = client.post('/api/services/catalog/email/accounts', + data=json.dumps({}), + content_type='application/json') + assert resp.status_code == 400 + + def test_400_on_value_error(self, client): + am = MagicMock() + am.provision.side_effect = ValueError('user already exists') + with patch.object(app_module, 'account_manager', am): + resp = client.post('/api/services/catalog/email/accounts', + data=json.dumps({'username': 'alice'}), + content_type='application/json') + assert resp.status_code == 400 + + def test_500_on_runtime_error(self, client): + am = MagicMock() + am.provision.side_effect = RuntimeError('docker exec failed') + with patch.object(app_module, 'account_manager', am): + resp = client.post('/api/services/catalog/email/accounts', + data=json.dumps({'username': 'alice'}), + content_type='application/json') + assert resp.status_code == 500 + + def test_response_shape(self, client): + am = MagicMock() + am.provision.return_value = None + with patch.object(app_module, 'account_manager', am): + resp = client.post('/api/services/catalog/email/accounts', + data=json.dumps({'username': 'alice'}), + content_type='application/json') + data = json.loads(resp.data) + assert data['service_id'] == 'email' + assert data['username'] == 'alice' + assert data['provisioned'] is True + + +# --------------------------------------------------------------------------- +# DELETE /api/services/catalog//accounts/ +# --------------------------------------------------------------------------- + +class TestDeprovisionServiceAccount: + def test_returns_200_on_success(self, client): + am = MagicMock() + am.deprovision.return_value = True + with patch.object(app_module, 'account_manager', am): + resp = client.delete('/api/services/catalog/email/accounts/alice') + assert resp.status_code == 200 + + def test_returns_500_when_deprovision_fails(self, client): + am = MagicMock() + am.deprovision.return_value = False + with patch.object(app_module, 'account_manager', am): + resp = client.delete('/api/services/catalog/email/accounts/alice') + assert resp.status_code == 500 + + def test_400_on_value_error(self, client): + am = MagicMock() + am.deprovision.side_effect = ValueError('not found') + with patch.object(app_module, 'account_manager', am): + resp = client.delete('/api/services/catalog/email/accounts/alice') + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# GET /api/services/catalog//accounts//credentials +# --------------------------------------------------------------------------- + +class TestGetServiceAccountCredentials: + def test_returns_200_with_creds(self, client): + am = MagicMock() + am.get_credentials.return_value = {'password': 'secret'} + with patch.object(app_module, 'account_manager', am): + resp = client.get('/api/services/catalog/email/accounts/alice/credentials') + assert resp.status_code == 200 + + def test_returns_404_when_not_provisioned(self, client): + am = MagicMock() + am.get_credentials.return_value = None + with patch.object(app_module, 'account_manager', am): + resp = client.get('/api/services/catalog/email/accounts/alice/credentials') + assert resp.status_code == 404 + + def test_returns_creds_in_response(self, client): + am = MagicMock() + am.get_credentials.return_value = {'password': 'mypass'} + with patch.object(app_module, 'account_manager', am): + resp = client.get('/api/services/catalog/email/accounts/alice/credentials') + data = json.loads(resp.data) + assert data['service_id'] == 'email' + assert data['username'] == 'alice' + assert 'password' in data + + +# --------------------------------------------------------------------------- +# GET /api/services/bus/status +# --------------------------------------------------------------------------- + +class TestServiceBusStatus: + def test_returns_200(self, client): + mock_sb = MagicMock() + mock_sb.get_service_status_summary.return_value = {} + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.get('/api/services/bus/status') + assert resp.status_code == 200 + + def test_500_on_exception(self, client): + mock_sb = MagicMock() + mock_sb.get_service_status_summary.side_effect = Exception('fail') + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.get('/api/services/bus/status') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/services/bus/services//start +# --------------------------------------------------------------------------- + +class TestServiceBusStart: + def test_returns_200_on_success(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_start.return_value = True + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/network/start') + assert resp.status_code == 200 + + def test_returns_500_on_failure(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_start.return_value = False + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/network/start') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/services/bus/services//stop +# --------------------------------------------------------------------------- + +class TestServiceBusStop: + def test_returns_200_on_success(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_stop.return_value = True + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/network/stop') + assert resp.status_code == 200 + + def test_returns_500_on_failure(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_stop.return_value = False + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/network/stop') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/services/bus/services//restart +# --------------------------------------------------------------------------- + +class TestServiceBusRestart: + def test_returns_200_on_success(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_restart.return_value = True + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/email/restart') + assert resp.status_code == 200 + + def test_returns_500_on_failure(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_restart.return_value = False + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/email/restart') + assert resp.status_code == 500 + + def test_500_on_exception(self, client): + mock_sb = MagicMock() + mock_sb.orchestrate_service_restart.side_effect = Exception('bus error') + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.post('/api/services/bus/services/email/restart') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/bus/events +# --------------------------------------------------------------------------- + +class TestServiceBusEvents: + def test_returns_200(self, client): + mock_sb = MagicMock() + mock_sb.get_event_history.return_value = [] + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.get('/api/services/bus/events') + assert resp.status_code == 200 + + def test_returns_list(self, client): + mock_sb = MagicMock() + mock_sb.get_event_history.return_value = [] + with patch.object(app_module, 'service_bus', mock_sb): + resp = client.get('/api/services/bus/events') + data = json.loads(resp.data) + assert isinstance(data, list) + + +# --------------------------------------------------------------------------- +# GET /api/logs/services/ +# --------------------------------------------------------------------------- + +class TestGetServiceLogs: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.get_service_logs.return_value = ['line1', 'line2'] + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/services/network') + assert resp.status_code == 200 + + def test_returns_log_lines(self, client): + mock_lm = MagicMock() + mock_lm.get_service_logs.return_value = ['line1'] + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/services/network') + data = json.loads(resp.data) + assert 'logs' in data + assert data['service'] == 'network' + + def test_500_on_exception(self, client): + mock_lm = MagicMock() + mock_lm.get_service_logs.side_effect = Exception('fail') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/services/network') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# POST /api/logs/search +# --------------------------------------------------------------------------- + +class TestSearchLogs: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.search_logs.return_value = [] + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/search', + data=json.dumps({'query': 'error'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_returns_results(self, client): + mock_lm = MagicMock() + mock_lm.search_logs.return_value = [{'message': 'error occurred'}] + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/search', + data=json.dumps({'query': 'error'}), + content_type='application/json') + data = json.loads(resp.data) + assert 'results' in data + assert 'count' in data + assert data['count'] == 1 + + +# --------------------------------------------------------------------------- +# GET /api/logs/statistics +# --------------------------------------------------------------------------- + +class TestLogStatistics: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.get_log_statistics.return_value = {'total': 100} + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/statistics') + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# POST /api/logs/rotate +# --------------------------------------------------------------------------- + +class TestRotateLogs: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.rotate_logs.return_value = None + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/rotate', + data=json.dumps({}), + content_type='application/json') + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# GET /api/logs/files +# --------------------------------------------------------------------------- + +class TestGetLogFiles: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.get_all_log_file_infos.return_value = [] + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/files') + assert resp.status_code == 200 + + def test_returns_file_list(self, client): + mock_lm = MagicMock() + mock_lm.get_all_log_file_infos.return_value = [ + {'file': 'network.log', 'size': 1024, 'backup': False} + ] + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/files') + data = json.loads(resp.data) + assert len(data) == 1 + assert data[0]['file'] == 'network.log' + + def test_500_on_exception(self, client): + mock_lm = MagicMock() + mock_lm.get_all_log_file_infos.side_effect = Exception('disk error') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/files') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/logs/verbosity +# --------------------------------------------------------------------------- + +class TestGetLogVerbosity: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.get_service_levels.return_value = {'network': 'INFO'} + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/verbosity') + assert resp.status_code == 200 + + def test_returns_service_levels(self, client): + mock_lm = MagicMock() + mock_lm.get_service_levels.return_value = {'network': 'DEBUG', 'email': 'INFO'} + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/verbosity') + data = json.loads(resp.data) + assert data['network'] == 'DEBUG' + assert data['email'] == 'INFO' + + def test_500_on_exception(self, client): + mock_lm = MagicMock() + mock_lm.get_service_levels.side_effect = Exception('fail') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/verbosity') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# PUT /api/logs/verbosity +# --------------------------------------------------------------------------- + +class TestSetLogVerbosity: + def test_returns_200(self, client): + import tempfile, os + mock_lm = MagicMock() + mock_lm.set_service_level.return_value = None + mock_lm.get_service_levels.return_value = {'network': 'DEBUG'} + with patch.object(app_module, 'log_manager', mock_lm): + with tempfile.TemporaryDirectory() as tmpdir: + with patch.dict(os.environ, {'CONFIG_DIR': tmpdir}): + resp = client.put('/api/logs/verbosity', + data=json.dumps({'network': 'DEBUG'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_calls_set_service_level(self, client): + import tempfile, os + mock_lm = MagicMock() + mock_lm.get_service_levels.return_value = {} + with patch.object(app_module, 'log_manager', mock_lm): + with tempfile.TemporaryDirectory() as tmpdir: + with patch.dict(os.environ, {'CONFIG_DIR': tmpdir}): + client.put('/api/logs/verbosity', + data=json.dumps({'network': 'DEBUG'}), + content_type='application/json') + mock_lm.set_service_level.assert_called_with('network', 'DEBUG') + + def test_500_on_exception(self, client): + mock_lm = MagicMock() + mock_lm.set_service_level.side_effect = Exception('fail') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.put('/api/logs/verbosity', + data=json.dumps({'network': 'DEBUG'}), + content_type='application/json') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/active +# --------------------------------------------------------------------------- + +class TestGetActiveServices: + def test_returns_200(self, client): + reg = _make_registry([_make_service('email')]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/active') + assert resp.status_code == 200 + + def test_returns_active_services_list(self, client): + reg = _make_registry([ + {'id': 'email', 'name': 'Email', 'subdomain': 'mail', 'capabilities': {'smtp': True}}, + ]) + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/active') + data = json.loads(resp.data) + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]['id'] == 'email' + + def test_500_on_exception(self, client): + reg = MagicMock() + reg.list_active.side_effect = Exception('error') + with patch.object(app_module, 'service_registry', reg): + resp = client.get('/api/services/active') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/status +# --------------------------------------------------------------------------- + +class TestGetAllServicesStatus: + def test_returns_200(self, client): + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/status') + assert resp.status_code == 200 + + def test_returns_status_for_each_service(self, client): + mock_svc = MagicMock() + mock_svc.get_status.return_value = { + 'status': 'online', 'running': True, 'timestamp': '2026-01-01T00:00:00' + } + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = ['network'] + mock_sbus.get_service.return_value = mock_svc + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/status') + data = json.loads(resp.data) + assert 'network' in data + + def test_service_exception_sets_error_in_status(self, client): + mock_svc = MagicMock() + mock_svc.get_status.side_effect = Exception('service down') + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = ['email'] + mock_sbus.get_service.return_value = mock_svc + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/status') + data = json.loads(resp.data) + assert 'email' in data + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# GET /api/services/connectivity +# --------------------------------------------------------------------------- + +class TestTestAllServicesConnectivity: + def test_returns_200(self, client): + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = [] + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/connectivity') + assert resp.status_code == 200 + + def test_includes_connectivity_result(self, client): + mock_svc = MagicMock() + mock_svc.test_connectivity.return_value = {'success': True} + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = ['network'] + mock_sbus.get_service.return_value = mock_svc + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/connectivity') + data = json.loads(resp.data) + assert 'network' in data + + def test_500_on_bus_exception(self, client): + mock_sbus = MagicMock() + mock_sbus.list_services.side_effect = Exception('bus fail') + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/connectivity') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/logs (backend logs endpoint) +# --------------------------------------------------------------------------- + +class TestGetBackendLogs: + def test_returns_404_when_log_file_missing(self, client): + # The default state: picell.log does not exist in the api/ directory + import os + import routes.services as svc_routes + # If the log file doesn't exist, it should return 404 + original = os.path.exists + def fake_exists(path): + if 'picell.log' in path: + return False + return original(path) + with patch('routes.services.os.path.exists', side_effect=fake_exists): + resp = client.get('/api/logs') + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# POST /api/logs/export +# --------------------------------------------------------------------------- + +class TestExportLogsEndpoint: + def test_returns_200(self, client): + mock_lm = MagicMock() + mock_lm.export_logs.return_value = '[]' + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/export', + data=json.dumps({'format': 'json'}), + content_type='application/json') + assert resp.status_code == 200 + + def test_returns_logs_and_format(self, client): + mock_lm = MagicMock() + mock_lm.export_logs.return_value = '[{"level":"INFO"}]' + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/export', + data=json.dumps({'format': 'json'}), + content_type='application/json') + data = json.loads(resp.data) + assert 'logs' in data + assert data['format'] == 'json' + + def test_500_on_exception(self, client): + mock_lm = MagicMock() + mock_lm.export_logs.side_effect = Exception('export failed') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/export', + data=json.dumps({'format': 'json'}), + content_type='application/json') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# GET /api/services/status — service-specific branches (wireguard, email, etc.) +# --------------------------------------------------------------------------- + +class TestGetAllServicesStatusServiceBranches: + """Cover the elif branches in get_all_services_status for each service type.""" + + def _status_for_service(self, client, service_name, raw_status): + mock_svc = MagicMock() + mock_svc.get_status.return_value = dict( + {'status': 'online', 'running': True, 'timestamp': '2026-01-01T00:00:00'}, + **raw_status + ) + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = [service_name] + mock_sbus.get_service.return_value = mock_svc + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/status') + return json.loads(resp.data) + + def test_wireguard_branch_includes_peers_count(self, client): + data = self._status_for_service(client, 'wireguard', + {'peers_count': 3, 'interface': 'wg0'}) + assert data['wireguard'].get('peers_count') == 3 + assert data['wireguard'].get('interface') == 'wg0' + + def test_email_branch_includes_users_count(self, client): + data = self._status_for_service(client, 'email', + {'users_count': 5, 'domain': 'cell.local'}) + assert data['email'].get('users_count') == 5 + assert data['email'].get('domain') == 'cell.local' + + def test_calendar_branch_includes_calendars_count(self, client): + data = self._status_for_service(client, 'calendar', + {'users_count': 2, 'calendars_count': 4}) + assert data['calendar'].get('calendars_count') == 4 + + def test_files_branch_includes_storage_used(self, client): + data = self._status_for_service(client, 'files', + {'users_count': 1, 'total_storage_used': {'used': 100}}) + assert data['files'].get('storage_used') == {'used': 100} + + def test_routing_branch_includes_nat_rules_count(self, client): + data = self._status_for_service(client, 'routing', + {'nat_rules_count': 2, 'peer_routes_count': 1, + 'firewall_rules_count': 3}) + assert data['routing'].get('nat_rules_count') == 2 + + def test_vault_branch_includes_certificates_count(self, client): + data = self._status_for_service(client, 'vault', + {'certificates_count': 5, 'trusted_keys_count': 2}) + assert data['vault'].get('certificates_count') == 5 + + def test_non_dict_status_uses_string_fallback(self, client): + """When get_status returns a non-dict, it gets stored as a string+bool.""" + mock_svc = MagicMock() + mock_svc.get_status.return_value = 'running' + mock_sbus = MagicMock() + mock_sbus.list_services.return_value = ['network'] + mock_sbus.get_service.return_value = mock_svc + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/status') + data = json.loads(resp.data) + assert 'network' in data + + +# --------------------------------------------------------------------------- +# Error paths in catalog service management endpoints +# --------------------------------------------------------------------------- + +class TestServiceCatalogErrorPaths: + """Cover ValueError/Exception branches in catalog management endpoints.""" + + def test_get_service_container_status_value_error_returns_400(self, client): + mock_reg = MagicMock() + mock_reg.get.return_value = {'id': 'email', 'kind': 'builtin'} + mock_composer = MagicMock() + mock_composer.status_service.side_effect = ValueError('bad request') + with patch.object(app_module, 'service_registry', mock_reg), \ + patch.object(app_module, 'service_composer', mock_composer): + resp = client.get('/api/services/catalog/email/status') + assert resp.status_code == 400 + + def test_restart_service_containers_value_error_returns_400(self, client): + mock_reg = MagicMock() + mock_reg.get.return_value = {'id': 'email', 'kind': 'builtin'} + mock_composer = MagicMock() + mock_composer.restart_service.side_effect = ValueError('bad service') + with patch.object(app_module, 'service_registry', mock_reg), \ + patch.object(app_module, 'service_composer', mock_composer): + resp = client.post('/api/services/catalog/email/restart') + assert resp.status_code == 400 + + def test_restart_service_containers_exception_returns_500(self, client): + mock_reg = MagicMock() + mock_reg.get.return_value = {'id': 'email', 'kind': 'builtin'} + mock_composer = MagicMock() + mock_composer.restart_service.side_effect = Exception('docker down') + with patch.object(app_module, 'service_registry', mock_reg), \ + patch.object(app_module, 'service_composer', mock_composer): + resp = client.post('/api/services/catalog/email/restart') + assert resp.status_code == 500 + + def test_reconfigure_service_value_error_returns_400(self, client): + mock_reg = MagicMock() + mock_reg.get.return_value = {'id': 'myapp', 'kind': 'store'} + mock_composer = MagicMock() + mock_composer.has_compose_file.return_value = True + mock_composer.up.side_effect = ValueError('invalid config') + with patch.object(app_module, 'service_registry', mock_reg), \ + patch.object(app_module, 'service_composer', mock_composer): + resp = client.post('/api/services/catalog/myapp/reconfigure') + assert resp.status_code == 400 + + def test_reconfigure_service_exception_returns_500(self, client): + mock_reg = MagicMock() + mock_reg.get.return_value = {'id': 'myapp', 'kind': 'store'} + mock_composer = MagicMock() + mock_composer.has_compose_file.return_value = True + mock_composer.up.side_effect = Exception('compose fail') + with patch.object(app_module, 'service_registry', mock_reg), \ + patch.object(app_module, 'service_composer', mock_composer): + resp = client.post('/api/services/catalog/myapp/reconfigure') + assert resp.status_code == 500 + + def test_provision_service_account_exception_returns_500(self, client): + mock_am = MagicMock() + mock_am.provision.side_effect = Exception('db failure') + with patch.object(app_module, 'account_manager', mock_am): + resp = client.post('/api/services/catalog/email/accounts', + data=json.dumps({'username': 'alice'}), + content_type='application/json') + assert resp.status_code == 500 + + def test_deprovision_service_account_value_error_returns_400(self, client): + mock_am = MagicMock() + mock_am.deprovision.side_effect = ValueError('user not found') + with patch.object(app_module, 'account_manager', mock_am): + resp = client.delete('/api/services/catalog/email/accounts/alice') + assert resp.status_code == 400 + + def test_deprovision_service_account_exception_returns_500(self, client): + mock_am = MagicMock() + mock_am.deprovision.side_effect = Exception('db down') + with patch.object(app_module, 'account_manager', mock_am): + resp = client.delete('/api/services/catalog/email/accounts/alice') + assert resp.status_code == 500 + + def test_get_service_account_credentials_exception_returns_500(self, client): + mock_am = MagicMock() + mock_am.get_credentials.side_effect = Exception('db fail') + with patch.object(app_module, 'account_manager', mock_am): + resp = client.get('/api/services/catalog/email/accounts/alice/credentials') + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# Exception paths for bus event/start/stop and log endpoints +# --------------------------------------------------------------------------- + +class TestServiceBusEndpointExceptions: + """Cover exception paths for service bus and log endpoints.""" + + def test_get_service_bus_events_exception_returns_500(self, client): + mock_sbus = MagicMock() + mock_sbus.get_event_history.side_effect = Exception('bus crash') + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.get('/api/services/bus/events') + assert resp.status_code == 500 + + def test_start_service_exception_returns_500(self, client): + mock_sbus = MagicMock() + mock_sbus.orchestrate_service_start.side_effect = Exception('crash') + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.post('/api/services/bus/services/email/start') + assert resp.status_code == 500 + + def test_stop_service_exception_returns_500(self, client): + mock_sbus = MagicMock() + mock_sbus.orchestrate_service_stop.side_effect = Exception('crash') + with patch.object(app_module, 'service_bus', mock_sbus): + resp = client.post('/api/services/bus/services/email/stop') + assert resp.status_code == 500 + + def test_search_logs_exception_returns_500(self, client): + mock_lm = MagicMock() + mock_lm.search_logs.side_effect = Exception('search fail') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/search', + data='{}', + content_type='application/json') + assert resp.status_code == 500 + + def test_get_log_statistics_exception_returns_500(self, client): + mock_lm = MagicMock() + mock_lm.get_log_statistics.side_effect = Exception('stats fail') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.get('/api/logs/statistics') + assert resp.status_code == 500 + + def test_rotate_logs_exception_returns_500(self, client): + mock_lm = MagicMock() + mock_lm.rotate_logs.side_effect = Exception('rotate fail') + with patch.object(app_module, 'log_manager', mock_lm): + resp = client.post('/api/logs/rotate', + data='{}', + content_type='application/json') + assert resp.status_code == 500 diff --git a/tests/test_routing_manager.py b/tests/test_routing_manager.py index c96ce96..850c7b6 100644 --- a/tests/test_routing_manager.py +++ b/tests/test_routing_manager.py @@ -1,149 +1,924 @@ -import sys -from pathlib import Path - -# Add api directory to path -api_dir = Path(__file__).parent.parent / 'api' -sys.path.insert(0, str(api_dir)) -import unittest -import tempfile -import shutil -import os -from unittest.mock import patch, MagicMock -from routing_manager import RoutingManager -import json - -class TestRoutingManager(unittest.TestCase): - def setUp(self): - self.test_dir = tempfile.mkdtemp() - self.data_dir = os.path.join(self.test_dir, 'data') - self.config_dir = os.path.join(self.test_dir, 'config') - os.makedirs(self.data_dir, exist_ok=True) - os.makedirs(self.config_dir, exist_ok=True) - self.manager = RoutingManager(data_dir=self.data_dir, config_dir=self.config_dir) - - def tearDown(self): - shutil.rmtree(self.test_dir) - - def test_initialization(self): - # Test RoutingManager initialization and config creation - self.assertTrue(os.path.exists(self.manager.routing_dir)) - self.assertTrue(os.path.exists(self.manager.rules_file)) - # Check that rules file contains default structure - with open(self.manager.rules_file) as f: - rules = json.load(f) - self.assertIn('nat_rules', rules) - self.assertIn('peer_routes', rules) - self.assertIn('exit_nodes', rules) - self.assertIn('bridge_routes', rules) - self.assertIn('split_routes', rules) - self.assertIn('firewall_rules', rules) - self.assertIsInstance(rules['nat_rules'], list) - self.assertIsInstance(rules['peer_routes'], dict) - self.assertIsInstance(rules['exit_nodes'], list) - self.assertIsInstance(rules['bridge_routes'], list) - self.assertIsInstance(rules['split_routes'], list) - self.assertIsInstance(rules['firewall_rules'], list) - - @patch.object(RoutingManager, '_apply_nat_rule', return_value=True) - @patch.object(RoutingManager, '_remove_nat_rule', return_value=True) - def test_add_and_remove_nat_rule(self, mock_remove_nat, mock_apply_nat): - # Add a valid NAT rule - result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0') - self.assertTrue(result) - # Check that the rule is persisted - with open(self.manager.rules_file) as f: - rules = json.load(f) - self.assertEqual(len(rules['nat_rules']), 1) - rule = rules['nat_rules'][0] - self.assertEqual(rule['source_network'], '10.0.0.0/24') - self.assertEqual(rule['target_interface'], 'eth0') - self.assertEqual(rule['nat_type'], 'MASQUERADE') - self.assertTrue(rule['enabled']) - # Remove the NAT rule - rule_id = rule['id'] - result = self.manager.remove_nat_rule(rule_id) - self.assertTrue(result) - with open(self.manager.rules_file) as f: - rules = json.load(f) - self.assertEqual(len(rules['nat_rules']), 0) - # Test invalid NAT rule (bad CIDR) - result = self.manager.add_nat_rule('bad-cidr', 'eth0') - self.assertFalse(result) - # Test invalid NAT rule (bad interface) - result = self.manager.add_nat_rule('10.0.0.0/24', '') - self.assertFalse(result) - # Test invalid NAT rule (bad nat_type) - result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0', nat_type='INVALID') - self.assertFalse(result) - # Test invalid NAT rule (bad protocol) - result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0', protocol='INVALID') - self.assertFalse(result) - - @patch.object(RoutingManager, '_apply_peer_route', return_value=True) - @patch.object(RoutingManager, '_remove_peer_route', return_value=True) - def test_add_and_remove_peer_route(self, mock_remove_peer, mock_apply_peer): - # Add a valid peer route - allowed_networks = ['10.0.0.0/24'] - result = self.manager.add_peer_route('peer1', '10.0.0.2', allowed_networks) - self.assertTrue(result) - # Check that the route is persisted - with open(self.manager.rules_file) as f: - rules = json.load(f) - self.assertIn('peer1', rules['peer_routes']) - route = rules['peer_routes']['peer1'] - self.assertEqual(route['peer_name'], 'peer1') - self.assertEqual(route['peer_ip'], '10.0.0.2') - self.assertEqual(route['allowed_networks'], allowed_networks) - self.assertEqual(route['route_type'], 'lan') - self.assertTrue(route['enabled']) - # Remove the peer route - result = self.manager.remove_peer_route('peer1') - self.assertTrue(result) - with open(self.manager.rules_file) as f: - rules = json.load(f) - self.assertNotIn('peer1', rules['peer_routes']) - # Test invalid peer route (bad peer_name) - result = self.manager.add_peer_route('', '10.0.0.2', allowed_networks) - self.assertFalse(result) - # Test invalid peer route (bad peer_ip) - result = self.manager.add_peer_route('peer2', '', allowed_networks) - self.assertFalse(result) - # Test invalid peer route (bad allowed_networks) - result = self.manager.add_peer_route('peer3', '10.0.0.3', ['bad-cidr']) - self.assertFalse(result) - # Test invalid peer route (bad route_type) - result = self.manager.add_peer_route('peer4', '10.0.0.4', allowed_networks, route_type='invalid') - self.assertFalse(result) - - def test_add_exit_node(self): - pass # Test adding exit node configuration - - def test_add_bridge_route(self): - pass # Test adding bridge route between peers - - def test_add_split_route(self): - pass # Test adding split routing rule - - def test_add_firewall_rule(self): - pass # Test adding firewall rule - - def test_get_routing_status(self): - pass # Test routing status and monitoring - - def test_test_routing_connectivity(self): - pass # Test routing connectivity - - def test_get_routing_logs(self): - pass # Test log collection - - def test_error_handling(self): - pass # Test error handling and edge cases - - def test_subprocess_command_execution(self): - pass # Test subprocess command execution (mocked) - - def test_route_parsing_and_analysis(self): - pass # Test route parsing and analysis - -if __name__ == '__main__': - unittest.main() \ No newline at end of file +import sys +import subprocess +from pathlib import Path + +# Add api directory to path +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) +import unittest +import tempfile +import shutil +import os +from unittest.mock import patch, MagicMock, call +from routing_manager import RoutingManager +import json + +class TestRoutingManager(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.data_dir = os.path.join(self.test_dir, 'data') + self.config_dir = os.path.join(self.test_dir, 'config') + os.makedirs(self.data_dir, exist_ok=True) + os.makedirs(self.config_dir, exist_ok=True) + self.manager = RoutingManager(data_dir=self.data_dir, config_dir=self.config_dir) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def test_initialization(self): + # Test RoutingManager initialization and config creation + self.assertTrue(os.path.exists(self.manager.routing_dir)) + self.assertTrue(os.path.exists(self.manager.rules_file)) + # Check that rules file contains default structure + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertIn('nat_rules', rules) + self.assertIn('peer_routes', rules) + self.assertIn('exit_nodes', rules) + self.assertIn('bridge_routes', rules) + self.assertIn('split_routes', rules) + self.assertIn('firewall_rules', rules) + self.assertIsInstance(rules['nat_rules'], list) + self.assertIsInstance(rules['peer_routes'], dict) + self.assertIsInstance(rules['exit_nodes'], list) + self.assertIsInstance(rules['bridge_routes'], list) + self.assertIsInstance(rules['split_routes'], list) + self.assertIsInstance(rules['firewall_rules'], list) + + @patch.object(RoutingManager, '_apply_nat_rule', return_value=True) + @patch.object(RoutingManager, '_remove_nat_rule', return_value=True) + def test_add_and_remove_nat_rule(self, mock_remove_nat, mock_apply_nat): + # Add a valid NAT rule + result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0') + self.assertTrue(result) + # Check that the rule is persisted + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['nat_rules']), 1) + rule = rules['nat_rules'][0] + self.assertEqual(rule['source_network'], '10.0.0.0/24') + self.assertEqual(rule['target_interface'], 'eth0') + self.assertEqual(rule['nat_type'], 'MASQUERADE') + self.assertTrue(rule['enabled']) + # Remove the NAT rule + rule_id = rule['id'] + result = self.manager.remove_nat_rule(rule_id) + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['nat_rules']), 0) + # Test invalid NAT rule (bad CIDR) + result = self.manager.add_nat_rule('bad-cidr', 'eth0') + self.assertFalse(result) + # Test invalid NAT rule (bad interface) + result = self.manager.add_nat_rule('10.0.0.0/24', '') + self.assertFalse(result) + # Test invalid NAT rule (bad nat_type) + result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0', nat_type='INVALID') + self.assertFalse(result) + # Test invalid NAT rule (bad protocol) + result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0', protocol='INVALID') + self.assertFalse(result) + + @patch.object(RoutingManager, '_apply_peer_route', return_value=True) + @patch.object(RoutingManager, '_remove_peer_route', return_value=True) + def test_add_and_remove_peer_route(self, mock_remove_peer, mock_apply_peer): + # Add a valid peer route + allowed_networks = ['10.0.0.0/24'] + result = self.manager.add_peer_route('peer1', '10.0.0.2', allowed_networks) + self.assertTrue(result) + # Check that the route is persisted + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertIn('peer1', rules['peer_routes']) + route = rules['peer_routes']['peer1'] + self.assertEqual(route['peer_name'], 'peer1') + self.assertEqual(route['peer_ip'], '10.0.0.2') + self.assertEqual(route['allowed_networks'], allowed_networks) + self.assertEqual(route['route_type'], 'lan') + self.assertTrue(route['enabled']) + # Remove the peer route + result = self.manager.remove_peer_route('peer1') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertNotIn('peer1', rules['peer_routes']) + # Test invalid peer route (bad peer_name) + result = self.manager.add_peer_route('', '10.0.0.2', allowed_networks) + self.assertFalse(result) + # Test invalid peer route (bad peer_ip) + result = self.manager.add_peer_route('peer2', '', allowed_networks) + self.assertFalse(result) + # Test invalid peer route (bad allowed_networks) + result = self.manager.add_peer_route('peer3', '10.0.0.3', ['bad-cidr']) + self.assertFalse(result) + # Test invalid peer route (bad route_type) + result = self.manager.add_peer_route('peer4', '10.0.0.4', allowed_networks, route_type='invalid') + self.assertFalse(result) + + @patch.object(RoutingManager, '_apply_exit_node') + def test_add_exit_node_valid(self, mock_apply): + result = self.manager.add_exit_node('peer1', '10.0.0.2') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['exit_nodes']), 1) + node = rules['exit_nodes'][0] + self.assertEqual(node['peer_name'], 'peer1') + self.assertEqual(node['peer_ip'], '10.0.0.2') + self.assertTrue(node['enabled']) + mock_apply.assert_called_once() + + @patch.object(RoutingManager, '_apply_exit_node') + def test_add_exit_node_with_allowed_domains(self, mock_apply): + result = self.manager.add_exit_node('peer1', '10.0.0.2', allowed_domains=['example.com']) + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(rules['exit_nodes'][0]['allowed_domains'], ['example.com']) + + def test_add_exit_node_invalid_peer_name(self): + result = self.manager.add_exit_node('bad name!', '10.0.0.2') + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_exit_node_invalid_ip(self): + result = self.manager.add_exit_node('peer1', 'not-an-ip') + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_exit_node_invalid_domains(self): + result = self.manager.add_exit_node('peer1', '10.0.0.2', allowed_domains='not-a-list') + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_exit_node_invalid_domain_format(self): + result = self.manager.add_exit_node('peer1', '10.0.0.2', allowed_domains=['bad domain!']) + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + @patch.object(RoutingManager, '_apply_bridge_route') + def test_add_bridge_route_valid(self, mock_apply): + result = self.manager.add_bridge_route('src-peer', '192.168.1.0/24', ['10.0.0.0/24']) + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['bridge_routes']), 1) + route = rules['bridge_routes'][0] + self.assertEqual(route['source_peer'], 'src-peer') + self.assertEqual(route['target_peer'], '192.168.1.0/24') + mock_apply.assert_called_once() + + def test_add_bridge_route_invalid_source(self): + result = self.manager.add_bridge_route('bad name!', '192.168.1.0/24', ['10.0.0.0/24']) + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_bridge_route_invalid_target(self): + result = self.manager.add_bridge_route('src-peer', 'not-a-network', ['10.0.0.0/24']) + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_bridge_route_empty_allowed_networks(self): + result = self.manager.add_bridge_route('src-peer', '192.168.1.0/24', []) + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_bridge_route_invalid_network_in_list(self): + result = self.manager.add_bridge_route('src-peer', '192.168.1.0/24', ['not-a-cidr']) + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + @patch.object(RoutingManager, '_apply_split_route') + def test_add_split_route_valid(self, mock_apply): + result = self.manager.add_split_route('10.0.0.0/24', '10.0.0.1') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['split_routes']), 1) + route = rules['split_routes'][0] + self.assertEqual(route['network'], '10.0.0.0/24') + self.assertEqual(route['exit_peer'], '10.0.0.1') + mock_apply.assert_called_once() + + @patch.object(RoutingManager, '_apply_split_route') + def test_add_split_route_with_fallback(self, mock_apply): + result = self.manager.add_split_route('10.0.0.0/24', '10.0.0.1', fallback_peer='10.0.0.2') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(rules['split_routes'][0]['fallback_peer'], '10.0.0.2') + + def test_add_split_route_invalid_network(self): + result = self.manager.add_split_route('not-a-cidr', '10.0.0.1') + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_split_route_invalid_exit_peer(self): + result = self.manager.add_split_route('10.0.0.0/24', 'not-an-ip') + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + def test_add_split_route_invalid_fallback(self): + result = self.manager.add_split_route('10.0.0.0/24', '10.0.0.1', fallback_peer='not-ip') + self.assertIsInstance(result, dict) + self.assertFalse(result.get('success', True)) + + @patch.object(RoutingManager, '_apply_firewall_rule') + def test_add_firewall_rule_valid(self, mock_apply): + result = self.manager.add_firewall_rule('FORWARD', '10.0.0.0/24', '192.168.1.0/24') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['firewall_rules']), 1) + rule = rules['firewall_rules'][0] + self.assertEqual(rule['rule_type'], 'FORWARD') + self.assertEqual(rule['source'], '10.0.0.0/24') + self.assertEqual(rule['destination'], '192.168.1.0/24') + self.assertEqual(rule['action'], 'ACCEPT') + mock_apply.assert_called_once() + + @patch.object(RoutingManager, '_apply_firewall_rule') + def test_add_firewall_rule_with_port(self, mock_apply): + result = self.manager.add_firewall_rule( + 'INPUT', '10.0.0.0/24', '192.168.1.0/24', protocol='TCP', port='80') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + rule = rules['firewall_rules'][0] + self.assertEqual(rule['protocol'], 'TCP') + self.assertEqual(rule['port'], '80') + + @patch.object(RoutingManager, '_apply_firewall_rule') + def test_add_firewall_rule_with_port_range(self, mock_apply): + result = self.manager.add_firewall_rule( + 'INPUT', '10.0.0.0/24', '192.168.1.0/24', port_range='1000-2000') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(rules['firewall_rules'][0]['port_range'], '1000-2000') + + def test_add_firewall_rule_invalid_type(self): + result = self.manager.add_firewall_rule('BADCHAIN', '10.0.0.0/24', '192.168.1.0/24') + self.assertFalse(result) + + def test_add_firewall_rule_invalid_source(self): + result = self.manager.add_firewall_rule('FORWARD', 'not-cidr', '192.168.1.0/24') + self.assertFalse(result) + + def test_add_firewall_rule_invalid_destination(self): + result = self.manager.add_firewall_rule('FORWARD', '10.0.0.0/24', 'not-cidr') + self.assertFalse(result) + + def test_add_firewall_rule_invalid_action(self): + result = self.manager.add_firewall_rule('FORWARD', '10.0.0.0/24', '192.168.1.0/24', action='JUMP') + self.assertFalse(result) + + def test_add_firewall_rule_invalid_protocol(self): + result = self.manager.add_firewall_rule('FORWARD', '10.0.0.0/24', '192.168.1.0/24', protocol='SCTP') + self.assertFalse(result) + + def test_add_firewall_rule_invalid_port_value(self): + result = self.manager.add_firewall_rule( + 'FORWARD', '10.0.0.0/24', '192.168.1.0/24', port='99999') + self.assertFalse(result) + + def test_add_firewall_rule_port_not_number(self): + result = self.manager.add_firewall_rule( + 'FORWARD', '10.0.0.0/24', '192.168.1.0/24', port='abc') + self.assertFalse(result) + + def test_add_firewall_rule_invalid_port_range_format(self): + result = self.manager.add_firewall_rule( + 'FORWARD', '10.0.0.0/24', '192.168.1.0/24', port_range='abc-def') + self.assertFalse(result) + + @patch.object(RoutingManager, '_apply_firewall_rule') + @patch('subprocess.run') + def test_remove_firewall_rule(self, mock_sub, mock_apply): + mock_proc = MagicMock() + mock_proc.returncode = 0 + mock_sub.return_value = mock_proc + self.manager.add_firewall_rule('FORWARD', '10.0.0.0/24', '192.168.1.0/24') + with open(self.manager.rules_file) as f: + rules = json.load(f) + rule_id = rules['firewall_rules'][0]['id'] + result = self.manager.remove_firewall_rule(rule_id) + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(len(rules['firewall_rules']), 0) + + @patch.object(RoutingManager, '_apply_firewall_rule') + def test_remove_firewall_rule_not_found(self, mock_apply): + result = self.manager.remove_firewall_rule('nonexistent_id') + self.assertFalse(result) + + @patch.object(RoutingManager, '_apply_firewall_rule') + @patch('subprocess.run') + def test_remove_firewall_rule_with_tcp_port(self, mock_sub, mock_apply): + """remove_firewall_rule builds correct iptables -D cmd for TCP+port rule.""" + mock_sub.return_value = MagicMock(returncode=0) + self.manager.add_firewall_rule( + 'INPUT', '10.0.0.0/24', '192.168.1.0/24', protocol='TCP', port='443') + with open(self.manager.rules_file) as f: + rules = json.load(f) + rule_id = rules['firewall_rules'][0]['id'] + result = self.manager.remove_firewall_rule(rule_id) + self.assertTrue(result) + # Check subprocess was called with iptables -D + call_args = mock_sub.call_args[0][0] + self.assertIn('iptables', call_args) + self.assertIn('-D', call_args) + + def test_get_routing_status_empty(self): + status = self.manager.get_routing_status() + self.assertIn('nat_rules_count', status) + self.assertEqual(status['nat_rules_count'], 0) + self.assertIn('firewall_rules_count', status) + self.assertIn('peer_routes_count', status) + self.assertIn('exit_nodes_count', status) + self.assertIn('bridge_routes_count', status) + self.assertIn('split_routes_count', status) + self.assertIn('routing_table', status) + self.assertIn('active_rules', status) + + @patch.object(RoutingManager, '_apply_nat_rule') + @patch.object(RoutingManager, '_apply_firewall_rule') + @patch.object(RoutingManager, '_apply_peer_route') + def test_get_routing_status_counts(self, mock_peer, mock_fw, mock_nat): + self.manager.add_nat_rule('10.0.0.0/24', 'eth0') + self.manager.add_firewall_rule('FORWARD', '10.0.0.0/24', '192.168.1.0/24') + self.manager.add_peer_route('p1', '10.0.0.2', ['10.0.0.0/24']) + status = self.manager.get_routing_status() + self.assertEqual(status['nat_rules_count'], 1) + self.assertEqual(status['firewall_rules_count'], 1) + self.assertEqual(status['peer_routes_count'], 1) + + def test_get_routing_status_corrupted_file(self): + # Corrupt the rules file to trigger exception + with open(self.manager.rules_file, 'w') as f: + f.write('{corrupt') + status = self.manager.get_routing_status() + # Should return safe defaults + self.assertEqual(status['nat_rules_count'], 0) + + def test_get_nat_rules_empty(self): + rules = self.manager.get_nat_rules() + self.assertEqual(rules, []) + + @patch.object(RoutingManager, '_apply_nat_rule') + def test_get_nat_rules_returns_list(self, mock_apply): + self.manager.add_nat_rule('10.0.0.0/24', 'eth0') + rules = self.manager.get_nat_rules() + self.assertEqual(len(rules), 1) + self.assertEqual(rules[0]['source_network'], '10.0.0.0/24') + + def test_get_peer_routes_empty(self): + routes = self.manager.get_peer_routes() + self.assertEqual(routes, []) + + @patch.object(RoutingManager, '_apply_peer_route') + def test_get_peer_routes_returns_list(self, mock_apply): + self.manager.add_peer_route('peer1', '10.0.0.2', ['10.0.0.0/24']) + routes = self.manager.get_peer_routes() + self.assertEqual(len(routes), 1) + self.assertEqual(routes[0]['peer_name'], 'peer1') + + def test_get_firewall_rules_empty(self): + rules = self.manager.get_firewall_rules() + self.assertEqual(rules, []) + + @patch.object(RoutingManager, '_apply_peer_route') + def test_update_peer_ip_updates_and_reapplies(self, mock_apply): + self.manager.add_peer_route('peer1', '10.0.0.2', ['10.0.0.0/24']) + result = self.manager.update_peer_ip('peer1', '10.0.0.99') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + self.assertEqual(rules['peer_routes']['peer1']['peer_ip'], '10.0.0.99') + + def test_update_peer_ip_nonexistent_peer(self): + result = self.manager.update_peer_ip('nobody', '10.0.0.99') + self.assertFalse(result) + + def test_remove_peer_route_nonexistent(self): + result = self.manager.remove_peer_route('nobody') + self.assertFalse(result) + + def test_get_status_returns_dict(self): + status = self.manager.get_status() + self.assertIsInstance(status, dict) + self.assertIn('running', status) + self.assertIn('status', status) + self.assertIn('nat_rules_count', status) + + def test_start_and_stop_service(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0) + result = self.manager.start() + self.assertTrue(result) + self.assertTrue(self.manager._service_running) + result = self.manager.stop() + self.assertTrue(result) + self.assertFalse(self.manager._service_running) + + def test_start_persists_service_state(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0) + self.manager.start() + # State file should exist now + self.assertTrue(os.path.exists(self.manager._state_file)) + with open(self.manager._state_file) as f: + state = json.load(f) + self.assertTrue(state['running']) + + def test_stop_persists_service_state(self): + self.manager.stop() + with open(self.manager._state_file) as f: + state = json.load(f) + self.assertFalse(state['running']) + + def test_restart_service(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0) + with patch('time.sleep'): + result = self.manager.restart() + self.assertTrue(result) + + def test_test_connectivity_returns_dict(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout='', stderr='') + result = self.manager.test_connectivity() + self.assertIsInstance(result, dict) + self.assertIn('routing_functionality', result) + self.assertIn('iptables_access', result) + self.assertIn('network_interfaces', result) + self.assertIn('routing_table_access', result) + + def test_test_routing_connectivity_returns_results(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout='OK', stderr='') + results = self.manager.test_routing_connectivity('8.8.8.8') + self.assertIn('ping', results) + self.assertIn('traceroute', results) + + def test_test_routing_connectivity_with_via_peer(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout='OK', stderr='') + results = self.manager.test_routing_connectivity('8.8.8.8', via_peer='10.0.0.1') + self.assertIn('peer_route', results) + + def test_test_routing_connectivity_subprocess_exception(self): + with patch('subprocess.run', side_effect=Exception('timeout')): + results = self.manager.test_routing_connectivity('8.8.8.8') + self.assertIn('ping', results) + self.assertFalse(results['ping']['success']) + + def test_get_routing_logs(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout='log line', stderr='') + logs = self.manager.get_routing_logs() + self.assertIsInstance(logs, dict) + self.assertIn('routes', logs) + + def test_parse_route_basic(self): + parsed = self.manager._parse_route('10.0.0.0/24 via 172.20.0.1 dev eth0 metric 100') + self.assertEqual(parsed['destination'], '10.0.0.0/24') + self.assertEqual(parsed['via'], '172.20.0.1') + self.assertEqual(parsed['dev'], 'eth0') + self.assertEqual(parsed['metric'], '100') + + def test_parse_route_no_via(self): + parsed = self.manager._parse_route('10.0.0.0/24 dev eth0') + self.assertEqual(parsed['destination'], '10.0.0.0/24') + self.assertEqual(parsed['dev'], 'eth0') + self.assertEqual(parsed['via'], '') + + def test_parse_route_empty_string(self): + parsed = self.manager._parse_route('') + self.assertEqual(parsed['destination'], '') + + def test_validate_cidr_valid(self): + self.assertTrue(self.manager._validate_cidr('10.0.0.0/24')) + self.assertTrue(self.manager._validate_cidr('192.168.1.0/24')) + self.assertTrue(self.manager._validate_cidr('172.16.0.0/12')) + self.assertTrue(self.manager._validate_cidr('0.0.0.0/0')) + + def test_validate_cidr_invalid(self): + self.assertFalse(self.manager._validate_cidr('not-a-cidr')) + self.assertFalse(self.manager._validate_cidr('')) + self.assertFalse(self.manager._validate_cidr('10.0.0.300/24')) + + def test_load_service_state_from_file(self): + """State is restored from the file on the next instantiation.""" + self.manager._service_running = True + self.manager._save_service_state() + new_manager = RoutingManager(data_dir=self.data_dir, config_dir=self.config_dir) + self.assertTrue(new_manager._service_running) + + def test_load_service_state_no_file(self): + """Without a state file, service defaults to running=True.""" + if os.path.exists(self.manager._state_file): + os.remove(self.manager._state_file) + new_manager = RoutingManager(data_dir=self.data_dir, config_dir=self.config_dir) + self.assertTrue(new_manager._service_running) + + def test_get_live_iptables(self): + with patch('subprocess.run') as mock_sub: + mock_sub.return_value = MagicMock(returncode=0, stdout='Chain INPUT', stderr='') + result = self.manager.get_live_iptables() + self.assertIn('filter', result) + self.assertIn('nat', result) + + def test_get_live_iptables_subprocess_exception(self): + with patch('subprocess.run', side_effect=Exception('no docker')): + result = self.manager.get_live_iptables() + self.assertIn('filter', result) + self.assertIn('nat', result) + + @patch.object(RoutingManager, '_apply_nat_rule') + def test_nat_rule_snat_type(self, mock_apply): + result = self.manager.add_nat_rule( + '10.0.0.0/24', 'eth0', nat_type='SNAT', + internal_ip='10.0.0.5', internal_port='8080') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + rule = rules['nat_rules'][0] + self.assertEqual(rule['nat_type'], 'SNAT') + + @patch.object(RoutingManager, '_apply_nat_rule') + def test_nat_rule_dnat_type(self, mock_apply): + result = self.manager.add_nat_rule( + '10.0.0.0/24', 'eth0', nat_type='DNAT', + internal_ip='10.0.0.5', external_port='80', internal_port='8080') + self.assertTrue(result) + with open(self.manager.rules_file) as f: + rules = json.load(f) + rule = rules['nat_rules'][0] + self.assertEqual(rule['nat_type'], 'DNAT') + + @patch.object(RoutingManager, '_apply_nat_rule') + def test_nat_rule_tcp_protocol(self, mock_apply): + result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0', protocol='TCP') + self.assertTrue(result) + + @patch.object(RoutingManager, '_apply_nat_rule') + def test_nat_rule_udp_protocol(self, mock_apply): + result = self.manager.add_nat_rule('10.0.0.0/24', 'eth0', protocol='UDP') + self.assertTrue(result) + + @patch.object(RoutingManager, '_apply_peer_route') + def test_peer_route_exit_type(self, mock_apply): + result = self.manager.add_peer_route('p1', '10.0.0.2', ['10.0.0.0/24'], route_type='exit') + self.assertTrue(result) + + @patch.object(RoutingManager, '_apply_peer_route') + def test_peer_route_bridge_type(self, mock_apply): + result = self.manager.add_peer_route('p1', '10.0.0.2', ['10.0.0.0/24'], route_type='bridge') + self.assertTrue(result) + + @patch.object(RoutingManager, '_apply_peer_route') + def test_peer_route_split_type(self, mock_apply): + result = self.manager.add_peer_route('p1', '10.0.0.2', ['10.0.0.0/24'], route_type='split') + self.assertTrue(result) + + def test_ensure_config_exists_is_idempotent(self): + """Calling _ensure_config_exists twice does not raise.""" + self.manager._ensure_config_exists() + self.manager._ensure_config_exists() + self.assertTrue(os.path.exists(self.manager.rules_file)) + + def test_save_rules_failure_is_silent(self): + """_save_rules failure (e.g. permission error) doesn't raise.""" + with patch('builtins.open', side_effect=OSError('disk full')): + # Should not raise + self.manager._save_rules({'nat_rules': [], 'peer_routes': {}}) + + def test_load_rules_failure_returns_empty_dict(self): + """_load_rules failure returns {}.""" + with open(self.manager.rules_file, 'w') as f: + f.write('{corrupt') + result = self.manager._load_rules() + self.assertEqual(result, {}) + + def test_test_iptables_access_not_found(self): + """FileNotFoundError from iptables returns success=True (dev mode).""" + with patch('subprocess.run', side_effect=FileNotFoundError()): + result = self.manager._test_iptables_access() + self.assertTrue(result['success']) + + def test_test_network_interfaces_not_found(self): + """FileNotFoundError from ip link show returns success=True (dev mode).""" + with patch('subprocess.run', side_effect=FileNotFoundError()): + result = self.manager._test_network_interfaces() + self.assertTrue(result['success']) + + def test_test_routing_table_not_found(self): + """FileNotFoundError from ip route show returns success=True (dev mode).""" + with patch('subprocess.run', side_effect=FileNotFoundError()): + result = self.manager._test_routing_table_access() + self.assertTrue(result['success']) + + def test_is_routing_service_running_reflects_state(self): + self.manager._service_running = True + self.assertTrue(self.manager._is_routing_service_running()) + self.manager._service_running = False + self.assertFalse(self.manager._is_routing_service_running()) + + @patch('subprocess.run') + def test_apply_nat_rule_masquerade(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + rule = { + 'source_network': '10.0.0.0/24', + 'target_interface': 'eth0', + 'masquerade': True, + 'nat_type': 'MASQUERADE', + 'protocol': 'ALL', + 'internal_ip': None, + 'external_port': None, + 'internal_port': None, + } + self.manager._apply_nat_rule(rule) + mock_sub.assert_called_once() + call_args = mock_sub.call_args[0][0] + self.assertIn('iptables', call_args) + self.assertIn('MASQUERADE', call_args) + + @patch('subprocess.run') + def test_apply_nat_rule_dnat(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + rule = { + 'source_network': '10.0.0.0/24', + 'target_interface': 'eth0', + 'masquerade': False, + 'nat_type': 'DNAT', + 'protocol': 'TCP', + 'internal_ip': '192.168.1.10', + 'external_port': '80', + 'internal_port': '8080', + } + self.manager._apply_nat_rule(rule) + mock_sub.assert_called_once() + call_args = mock_sub.call_args[0][0] + self.assertIn('DNAT', call_args) + + @patch('subprocess.run') + def test_apply_nat_rule_snat(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + rule = { + 'source_network': '10.0.0.0/24', + 'target_interface': 'eth0', + 'masquerade': False, + 'nat_type': 'SNAT', + 'protocol': 'TCP', + 'internal_ip': '192.168.1.10', + 'external_port': None, + 'internal_port': '8080', + } + self.manager._apply_nat_rule(rule) + mock_sub.assert_called_once() + call_args = mock_sub.call_args[0][0] + self.assertIn('SNAT', call_args) + + @patch('subprocess.run') + def test_apply_firewall_rule_with_port(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + rule = { + 'rule_type': 'INPUT', + 'source': '10.0.0.0/24', + 'destination': '192.168.1.0/24', + 'action': 'ACCEPT', + 'protocol': 'TCP', + 'port': '443', + 'port_range': None, + } + self.manager._apply_firewall_rule(rule) + mock_sub.assert_called_once() + call_args = mock_sub.call_args[0][0] + self.assertIn('--dport', call_args) + self.assertIn('443', call_args) + + @patch('subprocess.run') + def test_apply_firewall_rule_with_port_range(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + rule = { + 'rule_type': 'INPUT', + 'source': '10.0.0.0/24', + 'destination': '192.168.1.0/24', + 'action': 'ACCEPT', + 'protocol': 'ALL', + 'port': None, + 'port_range': '1000-2000', + } + self.manager._apply_firewall_rule(rule) + call_args = mock_sub.call_args[0][0] + # Port range should be passed as 1000:2000 for iptables + self.assertIn('1000:2000', call_args) + + def test_parse_proc_net_route_valid_data(self): + """_parse_proc_net_route parses /proc/net/route format correctly.""" + import tempfile + route_content = ( + "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" + "eth0\t00000000\t0101A8C0\t0003\t0\t0\t100\t00000000\t0\t0\t0\n" + "eth0\t000011AC\t00000000\t0001\t0\t0\t0\tF0FFFFFF\t0\t0\t0\n" + ) + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write(route_content) + fname = f.name + try: + routes = self.manager._parse_proc_net_route(fname) + self.assertIsInstance(routes, list) + # Should have parsed the two non-header lines + self.assertEqual(len(routes), 2) + # First entry should be a default route + self.assertIn('route', routes[0]) + finally: + os.unlink(fname) + + +class TestRoutingManagerInternalMethods(unittest.TestCase): + """Tests for internal methods that are harder to reach via public API.""" + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.data_dir = os.path.join(self.test_dir, 'data') + self.config_dir = os.path.join(self.test_dir, 'config') + os.makedirs(self.data_dir, exist_ok=True) + os.makedirs(self.config_dir, exist_ok=True) + self.manager = RoutingManager(self.data_dir, self.config_dir) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + # ── _apply_peer_route ───────────────────────────────────────────────── + + @patch('subprocess.run') + def test_apply_peer_route_calls_ip_route(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + self.manager._apply_peer_route({ + 'peer_name': 'alice', + 'peer_ip': '10.0.0.2', + 'allowed_networks': ['192.168.100.0/24', '10.1.0.0/16'] + }) + self.assertEqual(mock_sub.call_count, 2) + + @patch('subprocess.run', side_effect=subprocess.CalledProcessError(1, 'ip')) + def test_apply_peer_route_exception_is_silent(self, _mock): + # Should not raise + self.manager._apply_peer_route({ + 'peer_name': 'alice', + 'peer_ip': '10.0.0.2', + 'allowed_networks': ['192.168.100.0/24'] + }) + + # ── _remove_peer_route ──────────────────────────────────────────────── + + @patch('subprocess.run') + def test_remove_peer_route_calls_ip_route_del(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + self.manager._remove_peer_route('alice') + mock_sub.assert_called_once() + call_args = mock_sub.call_args[0][0] + self.assertIn('del', call_args) + + @patch('subprocess.run', side_effect=FileNotFoundError('ip not found')) + def test_remove_peer_route_exception_is_silent(self, _mock): + self.manager._remove_peer_route('alice') # Should not raise + + # ── _apply_exit_node ────────────────────────────────────────────────── + + @patch('subprocess.run') + def test_apply_exit_node_adds_default_route(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + self.manager._apply_exit_node({ + 'peer_name': 'exitnode', + 'peer_ip': '10.0.0.2' + }) + call_args = mock_sub.call_args[0][0] + self.assertIn('default', call_args) + + @patch('subprocess.run', side_effect=Exception('permission denied')) + def test_apply_exit_node_exception_is_silent(self, _mock): + self.manager._apply_exit_node({'peer_name': 'exit', 'peer_ip': '10.0.0.2'}) + + # ── _apply_bridge_route ─────────────────────────────────────────────── + + @patch('subprocess.run') + def test_apply_bridge_route_adds_forward_rules(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + self.manager._apply_bridge_route({ + 'source_peer': 'src', + 'target_peer': 'dst', + 'allowed_networks': ['10.1.0.0/16', '10.2.0.0/16'] + }) + self.assertEqual(mock_sub.call_count, 2) + + @patch('subprocess.run', side_effect=Exception('fail')) + def test_apply_bridge_route_exception_is_silent(self, _mock): + self.manager._apply_bridge_route({ + 'source_peer': 'src', 'target_peer': 'dst', + 'allowed_networks': ['10.1.0.0/16'] + }) + + # ── _apply_split_route ──────────────────────────────────────────────── + + @patch('subprocess.run') + def test_apply_split_route_adds_specific_route(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + self.manager._apply_split_route({ + 'network': '10.100.0.0/16', + 'exit_peer': 'exit_wg', + 'priority': 100 + }) + self.assertEqual(mock_sub.call_count, 1) + call_args = mock_sub.call_args[0][0] + self.assertIn('10.100.0.0/16', call_args) + + @patch('subprocess.run', side_effect=Exception('fail')) + def test_apply_split_route_exception_is_silent(self, _mock): + self.manager._apply_split_route({ + 'network': '10.100.0.0/16', 'exit_peer': 'wg0', 'priority': 100 + }) + + # ── _remove_nat_rule ────────────────────────────────────────────────── + + @patch('subprocess.run') + def test_remove_nat_rule_calls_iptables_D(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0) + self.manager._remove_nat_rule('rule_abc') + call_args = mock_sub.call_args[0][0] + self.assertIn('-D', call_args) + self.assertIn('POSTROUTING', call_args) + + @patch('subprocess.run') + def test_remove_nat_rule_not_found_is_silent(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=1) + self.manager._remove_nat_rule('nonexistent_rule') # Should not raise + + @patch('subprocess.run', side_effect=FileNotFoundError('iptables')) + def test_remove_nat_rule_exception_is_silent(self, _mock): + self.manager._remove_nat_rule('rule_x') + + # ── get_routing_logs ────────────────────────────────────────────────── + + @patch('subprocess.run') + def test_get_routing_logs_returns_dict(self, mock_sub): + mock_sub.return_value = MagicMock(returncode=0, stdout='some output\n') + result = self.manager.get_routing_logs() + self.assertIsInstance(result, dict) + self.assertIn('routes', result) + + @patch('subprocess.run', side_effect=Exception('system error')) + def test_get_routing_logs_outer_exception_returns_error_dict(self, _mock): + # The outer try will succeed but inner dmesg calls will fail + result = self.manager.get_routing_logs() + self.assertIsInstance(result, dict) + + # ── remove_firewall_rule exception path ─────────────────────────────── + + @patch('subprocess.run') + def test_remove_firewall_rule_exception_returns_false(self, mock_sub): + mock_sub.side_effect = Exception('unexpected error') + # The outer try in remove_firewall_rule should catch and return False + # This requires the rules file to have a matching rule that triggers subprocess + result = self.manager.remove_firewall_rule('nonexistent_rule_id') + # nonexistent rule -> returns False (not found path) + self.assertFalse(result) + + # ── _get_routing_table ──────────────────────────────────────────────── + + @patch('subprocess.run') + def test_get_routing_table_returns_list(self, mock_sub): + mock_sub.return_value = MagicMock( + returncode=0, + stdout='default via 192.168.1.1 dev eth0\n10.0.0.0/24 dev wg0 proto kernel\n' + ) + result = self.manager._get_routing_table() + self.assertIsInstance(result, list) + + @patch('subprocess.run', side_effect=FileNotFoundError('docker')) + def test_get_routing_table_fallback_exception_returns_empty(self, _mock): + """When both /proc/1/net/route and docker exec fail, return empty list.""" + with patch('builtins.open', side_effect=FileNotFoundError('/proc/1/net/route')): + result = self.manager._get_routing_table() + self.assertEqual(result, []) + + # ── start with sysctl exception ─────────────────────────────────────── + + @patch('subprocess.run', side_effect=FileNotFoundError('sysctl')) + def test_start_continues_when_sysctl_not_found(self, _mock): + result = self.manager.start() + self.assertTrue(result) + self.assertTrue(self.manager._service_running) + + @patch('subprocess.run', side_effect=subprocess.CalledProcessError(1, 'sysctl')) + def test_start_continues_when_sysctl_fails(self, _mock): + result = self.manager.start() + self.assertTrue(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_service_bus_extra.py b/tests/test_service_bus_extra.py new file mode 100644 index 0000000..56a2629 --- /dev/null +++ b/tests/test_service_bus_extra.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +Additional tests for ServiceBus covering missed branches: +- unsubscribe_from_event: handler-not-found ValueError +- call_service: method raises exception (publishes ERROR_OCCURRED, re-raises) +- orchestrate_service_start/stop: with container manager registered +- orchestrate_service_restart: exception path +- _event_loop: processing events through handlers, handler exception +- add_service_dependency: new service key branch +- remove_service_dependency: dependency not found (ValueError branch) +- get_service_status_summary: service without get_status, service that raises +""" + +import sys +import time +import unittest +from pathlib import Path +from unittest.mock import Mock, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from service_bus import ServiceBus, EventType, Event + + +class TestUnsubscribeNotFound(unittest.TestCase): + def setUp(self): + self.bus = ServiceBus() + + def test_unsubscribe_handler_not_registered_does_not_raise(self): + # unsubscribing a handler that was never added should not raise + def handler(event): + pass + # should not raise ValueError + self.bus.unsubscribe_from_event(EventType.SERVICE_STARTED, handler) + + def test_unsubscribe_removes_correct_handler(self): + received = [] + def h1(event): received.append('h1') + def h2(event): received.append('h2') + self.bus.subscribe_to_event(EventType.SERVICE_STARTED, h1) + self.bus.subscribe_to_event(EventType.SERVICE_STARTED, h2) + self.bus.unsubscribe_from_event(EventType.SERVICE_STARTED, h1) + handlers = self.bus.event_handlers[EventType.SERVICE_STARTED] + self.assertNotIn(h1, handlers) + self.assertIn(h2, handlers) + + +class TestCallServiceException(unittest.TestCase): + def setUp(self): + self.bus = ServiceBus() + + def test_call_service_method_raises_publishes_error_and_reraises(self): + mock_svc = Mock() + mock_svc.failing_method.side_effect = RuntimeError("boom") + self.bus.register_service('svc', mock_svc) + + with self.assertRaises(RuntimeError): + self.bus.call_service('svc', 'failing_method') + + def test_call_service_error_is_added_to_queue(self): + mock_svc = Mock() + mock_svc.boom.side_effect = ValueError("bad value") + self.bus.register_service('svc', mock_svc) + + try: + self.bus.call_service('svc', 'boom') + except ValueError: + pass + + # The ERROR_OCCURRED event was put onto the queue + self.assertFalse(self.bus.event_queue.empty()) + + +class TestOrchestrateWithContainers(unittest.TestCase): + def setUp(self): + self.bus = ServiceBus() + + def _register_container_manager(self, start_return=True, stop_return=True): + cm = Mock() + cm.start_container.return_value = start_return + cm.stop_container.return_value = stop_return + self.bus.register_service('container', cm) + return cm + + def test_orchestrate_start_wireguard_starts_containers(self): + cm = self._register_container_manager(start_return=True) + # wireguard service has containers but is not registered as a service + result = self.bus.orchestrate_service_start('wireguard') + self.assertTrue(result) + cm.start_container.assert_called_with('cell-wireguard') + + def test_orchestrate_start_container_failure_returns_false(self): + cm = self._register_container_manager(start_return=False) + result = self.bus.orchestrate_service_start('wireguard') + self.assertFalse(result) + + def test_orchestrate_start_no_container_manager_returns_false(self): + # email has containers but 'container' manager is not registered + result = self.bus.orchestrate_service_start('email') + self.assertFalse(result) + + def test_orchestrate_stop_wireguard_stops_containers(self): + cm = self._register_container_manager(stop_return=True) + result = self.bus.orchestrate_service_stop('wireguard') + self.assertTrue(result) + cm.stop_container.assert_called_with('cell-wireguard') + + def test_orchestrate_stop_container_failure_returns_false(self): + cm = self._register_container_manager(stop_return=False) + result = self.bus.orchestrate_service_stop('wireguard') + self.assertFalse(result) + + def test_orchestrate_stop_no_container_manager_returns_false(self): + result = self.bus.orchestrate_service_stop('email') + self.assertFalse(result) + + def test_orchestrate_restart_exception_returns_false(self): + # Make stop raise an exception to trigger the except clause + cm = Mock() + cm.stop_container.side_effect = RuntimeError("docker gone") + self.bus.register_service('container', cm) + result = self.bus.orchestrate_service_restart('wireguard') + self.assertFalse(result) + + +class TestEventLoopProcessing(unittest.TestCase): + def test_event_loop_calls_handler(self): + bus = ServiceBus() + received = [] + + def handler(event): + received.append(event) + + bus.subscribe_to_event(EventType.SERVICE_STARTED, handler) + bus.start() + try: + bus.publish_event(EventType.SERVICE_STARTED, 'src', {'x': 1}) + time.sleep(0.3) + self.assertEqual(len(received), 1) + self.assertEqual(received[0].source, 'src') + finally: + bus.stop() + + def test_event_loop_handler_exception_does_not_stop_loop(self): + bus = ServiceBus() + received = [] + + def bad_handler(event): + raise RuntimeError("handler crash") + + def good_handler(event): + received.append(event) + + bus.subscribe_to_event(EventType.SERVICE_STARTED, bad_handler) + bus.subscribe_to_event(EventType.SERVICE_STARTED, good_handler) + bus.start() + try: + bus.publish_event(EventType.SERVICE_STARTED, 'src', {}) + time.sleep(0.3) + # Loop continues; good_handler was also called + self.assertEqual(len(received), 1) + finally: + bus.stop() + + def test_event_loop_history_trimmed_at_max(self): + bus = ServiceBus() + bus.max_history = 3 + bus.start() + try: + for i in range(5): + bus.publish_event(EventType.SERVICE_STARTED, f'src{i}', {}) + time.sleep(0.3) + self.assertLessEqual(len(bus.event_history), 3) + finally: + bus.stop() + + +class TestServiceDependencyEdgeCases(unittest.TestCase): + def setUp(self): + self.bus = ServiceBus() + + def test_add_dependency_creates_new_entry_for_unknown_service(self): + self.bus.add_service_dependency('brand_new_service', 'network') + self.assertIn('brand_new_service', self.bus.service_dependencies) + self.assertIn('network', self.bus.service_dependencies['brand_new_service']) + + def test_remove_dependency_not_found_does_not_raise(self): + self.bus.add_service_dependency('svc', 'dep1') + # removing a dependency that was never added should not raise + self.bus.remove_service_dependency('svc', 'dep_nonexistent') + + def test_remove_dependency_for_unknown_service_does_not_raise(self): + # service never had any dependencies registered + self.bus.remove_service_dependency('ghost_service', 'dep') + + +class TestServiceStatusSummaryBranches(unittest.TestCase): + def setUp(self): + self.bus = ServiceBus() + + def test_summary_service_without_get_status(self): + # A service object without a get_status attribute + svc = object() + self.bus.register_service('plain_obj', svc) + summary = self.bus.get_service_status_summary() + self.assertIn('plain_obj', summary['services']) + self.assertEqual( + summary['services']['plain_obj']['status'], + {'status': 'unknown'} + ) + + def test_summary_service_get_status_raises(self): + svc = Mock() + svc.get_status.side_effect = RuntimeError("status unavailable") + self.bus.register_service('broken_svc', svc) + summary = self.bus.get_service_status_summary() + self.assertIn('broken_svc', summary['services']) + self.assertIn('error', summary['services']['broken_svc']['status']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_setup_route.py b/tests/test_setup_route.py index 863aa1a..d344c09 100644 --- a/tests/test_setup_route.py +++ b/tests/test_setup_route.py @@ -131,3 +131,208 @@ def test_complete_setup_fires_identity_changed_on_success(client): mock_sbus.publish_event.assert_called_once() event_args = mock_sbus.publish_event.call_args assert event_args[0][1] == 'setup' + + +# --------------------------------------------------------------------------- +# GET /api/setup/status +# --------------------------------------------------------------------------- + +def test_get_setup_status_returns_200_when_incomplete(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.get_setup_status.return_value = {'step': 'cell_name', 'complete': False} + with patch('app.setup_manager', mock_sm): + resp = client.get('/api/setup/status') + assert resp.status_code == 200 + + +def test_get_setup_status_returns_410_when_already_complete(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = True + with patch('app.setup_manager', mock_sm): + resp = client.get('/api/setup/status') + assert resp.status_code == 410 + + +def test_get_setup_status_returns_setup_data(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.get_setup_status.return_value = {'step': 'cell_name', 'options': {}} + with patch('app.setup_manager', mock_sm): + resp = client.get('/api/setup/status') + data = json.loads(resp.data) + assert 'step' in data + + +# --------------------------------------------------------------------------- +# POST /api/setup/validate +# --------------------------------------------------------------------------- + +def _post_validate(client, payload): + return client.post( + '/api/setup/validate', + data=json.dumps(payload), + content_type='application/json', + ) + + +def test_validate_cell_name_valid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_cell_name.return_value = [] + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, {'step': 'cell_name', 'data': {'cell_name': 'mycel'}}) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data['valid'] is True + assert data['errors'] == [] + + +def test_validate_cell_name_invalid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_cell_name.return_value = ['Name too short'] + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, {'step': 'cell_name', 'data': {'cell_name': 'a'}}) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data['valid'] is False + assert len(data['errors']) > 0 + + +def test_validate_password_valid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_password.return_value = [] + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, {'step': 'password', 'data': {'password': 'StrongPass1!'}}) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data['valid'] is True + + +def test_validate_password_invalid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_password.return_value = ['Too short'] + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, {'step': 'password', 'data': {'password': 'weak'}}) + data = json.loads(resp.data) + assert data['valid'] is False + + +def test_validate_pic_ngo_available_when_available(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_cell_name.return_value = [] + with patch('app.setup_manager', mock_sm), \ + patch('routes.setup._check_pic_ngo_available', return_value=True): + resp = _post_validate(client, { + 'step': 'pic_ngo_available', 'data': {'cell_name': 'mycel'}}) + data = json.loads(resp.data) + assert data['available'] is True + + +def test_validate_pic_ngo_available_when_taken(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_cell_name.return_value = [] + with patch('app.setup_manager', mock_sm), \ + patch('routes.setup._check_pic_ngo_available', return_value=False): + resp = _post_validate(client, { + 'step': 'pic_ngo_available', 'data': {'cell_name': 'mycel'}}) + data = json.loads(resp.data) + assert data['available'] is False + + +def test_validate_pic_ngo_name_errors_block_check(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_cell_name.return_value = ['Invalid name'] + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, { + 'step': 'pic_ngo_available', 'data': {'cell_name': 'a'}}) + data = json.loads(resp.data) + assert data['available'] is False + + +def test_validate_pic_ngo_service_unreachable_returns_503(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + mock_sm.validate_cell_name.return_value = [] + with patch('app.setup_manager', mock_sm), \ + patch('routes.setup._check_pic_ngo_available', side_effect=Exception('timeout')): + resp = _post_validate(client, { + 'step': 'pic_ngo_available', 'data': {'cell_name': 'mycel'}}) + assert resp.status_code == 503 + data = json.loads(resp.data) + assert data['available'] is False + + +def test_validate_cloudflare_token_valid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + with patch('app.setup_manager', mock_sm), \ + patch('routes.setup._verify_cloudflare_token', return_value=True): + resp = _post_validate(client, { + 'step': 'cloudflare_token', 'data': {'token': 'mytoken'}}) + data = json.loads(resp.data) + assert data['valid'] is True + + +def test_validate_cloudflare_token_missing(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, { + 'step': 'cloudflare_token', 'data': {'token': ''}}) + data = json.loads(resp.data) + assert data['valid'] is False + + +def test_validate_cloudflare_token_invalid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + with patch('app.setup_manager', mock_sm), \ + patch('routes.setup._verify_cloudflare_token', return_value=False): + resp = _post_validate(client, { + 'step': 'cloudflare_token', 'data': {'token': 'badtoken'}}) + data = json.loads(resp.data) + assert data['valid'] is False + + +def test_validate_duckdns_token_valid(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + with patch('app.setup_manager', mock_sm), \ + patch('routes.setup._verify_duckdns_token', return_value=True): + resp = _post_validate(client, { + 'step': 'duckdns_token', 'data': {'subdomain': 'mycel', 'token': 'abc'}}) + data = json.loads(resp.data) + assert data['valid'] is True + + +def test_validate_duckdns_token_missing_fields(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, { + 'step': 'duckdns_token', 'data': {'subdomain': '', 'token': ''}}) + data = json.loads(resp.data) + assert data['valid'] is False + + +def test_validate_unknown_step_returns_400(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = False + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, {'step': 'unknown_step', 'data': {}}) + assert resp.status_code == 400 + + +def test_validate_returns_410_when_setup_complete(client): + mock_sm = MagicMock() + mock_sm.is_setup_complete.return_value = True + with patch('app.setup_manager', mock_sm): + resp = _post_validate(client, {'step': 'cell_name', 'data': {'cell_name': 'x'}}) + assert resp.status_code == 410