"""
Tests for PUT /api/config input validation.

Covers the 400 paths (malformed or out-of-range input) and the 409 paths
(conflicts with a connected cell). These were the highest-risk untested
paths: this validation is the only server-side guard against bad subnet/port
values entering persistent config.

NOTE(review): the "*_accepted" tests send *valid* PUTs to the real handler
and only assert on the status code. If the handler writes accepted values to
persistent config, these tests may overwrite it — confirm they run against
an isolated config.
"""
import json
import sys
import os
import unittest
from unittest.mock import patch, MagicMock

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'api'))


def _make_client():
    """Return a test client for the API ``app`` with TESTING enabled.

    The import happens here rather than at module top, so the app is only
    loaded once a test actually runs (after the ``sys.path`` insert above).
    """
    from app import app
    app.config['TESTING'] = True
    return app.test_client()


def _put(client, payload):
    """Send *payload* as a JSON body to ``PUT /api/config``; return the response."""
    return client.put(
        '/api/config',
        data=json.dumps(payload),
        content_type='application/json',
    )


# ---------------------------------------------------------------------------
# ip_range validation
# ---------------------------------------------------------------------------

class TestIpRangeValidation(unittest.TestCase):
    """ip_range must be valid CIDR inside an RFC-1918 private block, else 400."""

    def setUp(self):
        self.client = _make_client()

    def test_non_rfc1918_returns_400(self):
        r = _put(self.client, {'ip_range': '1.2.3.0/24'})
        self.assertEqual(r.status_code, 400)
        body = json.loads(r.data)
        self.assertIn('error', body)
        self.assertIn('RFC-1918', body['error'])

    def test_172_0_subnet_returns_400(self):
        # 172.0.0.0/24 is NOT in 172.16.0.0/12 — was the bug on the dev machine
        r = _put(self.client, {'ip_range': '172.0.0.0/24'})
        self.assertEqual(r.status_code, 400)

    def test_172_15_subnet_returns_400(self):
        # 172.15.x.x sits just below the start of 172.16.0.0/12
        r = _put(self.client, {'ip_range': '172.15.0.0/24'})
        self.assertEqual(r.status_code, 400)

    def test_172_32_subnet_returns_400(self):
        # 172.32.x.x sits just above the end of 172.16.0.0/12 (172.31.255.255)
        r = _put(self.client, {'ip_range': '172.32.0.0/24'})
        self.assertEqual(r.status_code, 400)

    def test_public_ip_returns_400(self):
        r = _put(self.client, {'ip_range': '8.8.0.0/16'})
        self.assertEqual(r.status_code, 400)

    def test_172_16_exact_boundary_accepted(self):
        # 172.16.0.0/12 is the exact lower boundary — must be valid
        r = _put(self.client, {'ip_range': '172.16.0.0/12'})
        # 200 or 202 — just not 400
        self.assertNotEqual(r.status_code, 400)

    def test_172_31_upper_boundary_accepted(self):
        # 172.31.0.0/16 is the last /16 inside 172.16.0.0/12 — must be valid
        r = _put(self.client, {'ip_range': '172.31.0.0/16'})
        self.assertNotEqual(r.status_code, 400)

    def test_10_network_accepted(self):
        r = _put(self.client, {'ip_range': '10.0.0.0/8'})
        self.assertNotEqual(r.status_code, 400)

    def test_192_168_network_accepted(self):
        r = _put(self.client, {'ip_range': '192.168.0.0/16'})
        self.assertNotEqual(r.status_code, 400)

    def test_invalid_cidr_syntax_returns_400(self):
        r = _put(self.client, {'ip_range': 'not-a-cidr'})
        self.assertEqual(r.status_code, 400)


# ---------------------------------------------------------------------------
# Port range validation
# ---------------------------------------------------------------------------

class TestPortValidation(unittest.TestCase):
    """Port fields must lie in 1..65535, else 400."""

    def setUp(self):
        self.client = _make_client()

    def test_dns_port_zero_returns_400(self):
        r = _put(self.client, {'network': {'dns_port': 0}})
        self.assertEqual(r.status_code, 400)
        body = json.loads(r.data)
        self.assertIn('dns_port', body.get('error', ''))

    def test_dns_port_65536_returns_400(self):
        r = _put(self.client, {'network': {'dns_port': 65536}})
        self.assertEqual(r.status_code, 400)

    def test_wireguard_port_zero_returns_400(self):
        r = _put(self.client, {'wireguard': {'port': 0}})
        self.assertEqual(r.status_code, 400)

    def test_wireguard_port_65536_returns_400(self):
        r = _put(self.client, {'wireguard': {'port': 65536}})
        self.assertEqual(r.status_code, 400)

    def test_wireguard_port_1_accepted(self):
        r = _put(self.client, {'wireguard': {'port': 1}})
        self.assertNotEqual(r.status_code, 400)

    def test_wireguard_port_65535_accepted(self):
        r = _put(self.client, {'wireguard': {'port': 65535}})
        self.assertNotEqual(r.status_code, 400)

    def test_email_smtp_port_zero_returns_400(self):
        r = _put(self.client, {'email': {'smtp_port': 0}})
        self.assertEqual(r.status_code, 400)

    def test_calendar_port_negative_returns_400(self):
        r = _put(self.client, {'calendar': {'port': -1}})
        self.assertEqual(r.status_code, 400)


# ---------------------------------------------------------------------------
# WireGuard address validation
# ---------------------------------------------------------------------------
class TestWireguardAddressValidation(unittest.TestCase):
    """wireguard.address must be an IP with a CIDR prefix, else 400."""

    def setUp(self):
        self.client = _make_client()

    def test_bad_wg_address_returns_400(self):
        r = _put(self.client, {'wireguard': {'address': 'not-an-ip'}})
        self.assertEqual(r.status_code, 400)
        body = json.loads(r.data)
        self.assertIn('wireguard.address', body.get('error', ''))

    def test_ip_without_prefix_returns_400(self):
        r = _put(self.client, {'wireguard': {'address': '10.0.0.1'}})
        self.assertEqual(r.status_code, 400)

    def test_valid_wg_address_accepted(self):
        r = _put(self.client, {'wireguard': {'address': '10.0.0.1/24'}})
        self.assertNotEqual(r.status_code, 400)


# ---------------------------------------------------------------------------
# Body validation
# ---------------------------------------------------------------------------

class TestBodyValidation(unittest.TestCase):
    """A missing or empty JSON body is rejected with 400."""

    def setUp(self):
        self.client = _make_client()

    def test_no_body_returns_400(self):
        r = self.client.put('/api/config', content_type='application/json')
        self.assertEqual(r.status_code, 400)

    def test_empty_body_returns_400(self):
        r = self.client.put('/api/config', data='', content_type='application/json')
        self.assertEqual(r.status_code, 400)

    def test_valid_cell_name_change_returns_200(self):
        r = _put(self.client, {'cell_name': 'testcell'})
        self.assertEqual(r.status_code, 200)


# ---------------------------------------------------------------------------
# Shared fixtures for the conflict (409) tests below
# ---------------------------------------------------------------------------

def _remote_cell(domain='remote.cell'):
    """Return a new connected-cell record named 'remote' on 10.5.0.0/24.

    A fresh dict is built on every call, so tests never share a mutable
    fixture even if the handler modifies the list it is given.
    """
    return {
        'cell_name': 'remote',
        'domain': domain,
        'vpn_subnet': '10.5.0.0/24',
        'dns_ip': '10.5.0.1',
    }


def _put_with_connections(client, payload, connections):
    """PUT *payload* to /api/config while ``app.cell_link_manager`` is patched
    so that ``list_connections()`` returns *connections*; return the response.
    """
    with patch('app.cell_link_manager') as mock_clm:
        mock_clm.list_connections.return_value = connections
        return _put(client, payload)


# ---------------------------------------------------------------------------
# Domain conflict validation
# ---------------------------------------------------------------------------

class TestDomainConflictValidation(unittest.TestCase):
    """Changing this cell's domain to one already used by a connected cell → 409."""

    def setUp(self):
        self.client = _make_client()
        self._connected = [_remote_cell(domain='other.cell')]

    def test_domain_matching_connected_cell_returns_409(self):
        """PUT /api/config with domain='other.cell' conflicts with a connected cell."""
        r = _put_with_connections(self.client, {'domain': 'other.cell'}, self._connected)
        self.assertEqual(r.status_code, 409)
        body = json.loads(r.data)
        # The error message names the conflicting cell.
        self.assertIn('remote', body.get('error', ''))

    def test_domain_not_matching_any_cell_is_accepted(self):
        """A domain not used by any connected cell is not rejected with 409."""
        r = _put_with_connections(self.client, {'domain': 'unique.cell'}, self._connected)
        self.assertNotEqual(r.status_code, 409)

    def test_domain_no_connected_cells_is_accepted(self):
        """With no connected cells, a domain change is not rejected with 409."""
        r = _put_with_connections(self.client, {'domain': 'any.cell'}, [])
        self.assertNotEqual(r.status_code, 409)


# ---------------------------------------------------------------------------
# WireGuard address subnet conflict validation
# ---------------------------------------------------------------------------

class TestWireGuardAddressConflictValidation(unittest.TestCase):
    """Changing wireguard.address to a subnet overlapping a connected cell → 409."""

    def setUp(self):
        self.client = _make_client()
        self._connected = [_remote_cell()]

    def test_wg_address_overlapping_connected_cell_returns_409(self):
        # 10.5.0.1/24 falls inside the remote cell's vpn_subnet 10.5.0.0/24.
        r = _put_with_connections(
            self.client, {'wireguard': {'address': '10.5.0.1/24'}}, self._connected)
        self.assertEqual(r.status_code, 409)
        body = json.loads(r.data)
        self.assertIn('remote', body.get('error', ''))

    def test_wg_address_non_overlapping_connected_cell_accepted(self):
        r = _put_with_connections(
            self.client, {'wireguard': {'address': '10.6.0.1/24'}}, self._connected)
        self.assertNotEqual(r.status_code, 409)

    def test_wg_address_no_connected_cells_accepted(self):
        r = _put_with_connections(
            self.client, {'wireguard': {'address': '10.5.0.1/24'}}, [])
        self.assertNotEqual(r.status_code, 409)

    def test_wg_address_missing_prefix_still_returns_400(self):
        """Format check fires before the conflict check."""
        r = _put_with_connections(
            self.client, {'wireguard': {'address': '10.5.0.1'}}, self._connected)
        self.assertEqual(r.status_code, 400)


# ---------------------------------------------------------------------------
# ip_range subnet conflict validation
# ---------------------------------------------------------------------------

class TestIpRangeConflictValidation(unittest.TestCase):
    """Changing ip_range to one overlapping a connected cell's vpn_subnet → 409."""

    def setUp(self):
        self.client = _make_client()
        self._connected = [_remote_cell()]

    def test_ip_range_overlapping_connected_cell_returns_409(self):
        # 10.5.0.0/16 contains the remote cell's vpn_subnet 10.5.0.0/24.
        r = _put_with_connections(
            self.client, {'ip_range': '10.5.0.0/16'}, self._connected)
        self.assertEqual(r.status_code, 409)
        body = json.loads(r.data)
        self.assertIn('remote', body.get('error', ''))

    def test_ip_range_non_overlapping_connected_cell_accepted(self):
        r = _put_with_connections(
            self.client, {'ip_range': '10.6.0.0/16'}, self._connected)
        self.assertNotEqual(r.status_code, 409)

    def test_ip_range_no_connected_cells_accepted(self):
        r = _put_with_connections(self.client, {'ip_range': '10.5.0.0/16'}, [])
        self.assertNotEqual(r.status_code, 409)

    def test_ip_range_non_rfc1918_still_returns_400(self):
        """RFC-1918 check fires before the conflict check."""
        r = _put_with_connections(
            self.client, {'ip_range': '8.8.8.0/24'}, self._connected)
        self.assertEqual(r.status_code, 400)


if __name__ == '__main__':
    unittest.main()