#!/usr/bin/env python3
"""
Tests for firewall_manager — per-peer iptables rule generation and DNS ACL logic.

All docker exec calls are mocked so tests run without a live Docker environment.
"""

import sys
import os
import tempfile
import shutil
import unittest
from unittest.mock import patch, call, MagicMock
from pathlib import Path

# Make the sibling ``api`` directory importable so ``firewall_manager``
# resolves without installing the project.
api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))

import firewall_manager


def _make_peer(ip, internet=True, services=None, peers=True):
    """Return a peer-settings dict in the shape these tests feed firewall_manager.

    Args:
        ip: The peer's VPN address (stored under ``'ip'``).
        internet: Value for ``'internet_access'``.
        services: Value for ``'service_access'``; when omitted, every key of
            ``firewall_manager.SERVICE_IPS`` is used (all services allowed).
        peers: Value for ``'peer_access'``.
    """
    if services is None:
        services = list(firewall_manager.SERVICE_IPS.keys())
    return {'ip': ip, 'internet_access': internet,
            'service_access': services, 'peer_access': peers}


# ---------------------------------------------------------------------------
# _peer_comment
# ---------------------------------------------------------------------------

class TestPeerComment(unittest.TestCase):

    def test_dots_replaced_with_dashes(self):
        # Comment format now includes /32 suffix to prevent substring matches
        # (e.g. pic-peer-10-0-0-1/32 is not a prefix of pic-peer-10-0-0-10/32)
        self.assertEqual(firewall_manager._peer_comment('10.0.0.2'),
                         'pic-peer-10-0-0-2/32')

    def test_different_ip(self):
        self.assertEqual(firewall_manager._peer_comment('192.168.1.100'),
                         'pic-peer-192-168-1-100/32')

    def test_comment_not_substring_of_longer_ip(self):
        # Guards the reason for the /32 suffix: removing rules for 10.0.0.1
        # must never match the tag of 10.0.0.10.
        short = firewall_manager._peer_comment('10.0.0.1')
        longer = firewall_manager._peer_comment('10.0.0.10')
        self.assertNotIn(short, longer)


# ---------------------------------------------------------------------------
# _build_acl_block
# ---------------------------------------------------------------------------

class TestBuildAclBlock(unittest.TestCase):

    def test_empty_returns_empty_string(self):
        self.assertEqual(firewall_manager._build_acl_block({}), '')

    def test_no_blocked_peers_returns_empty(self):
        blocked = {s: [] for s in firewall_manager.SERVICE_IPS}
        self.assertEqual(firewall_manager._build_acl_block(blocked), '')

    def test_blocked_peer_appears_in_acl(self):
        blocked = {'calendar': ['10.0.0.5'], 'files': [], 'mail': [], 'webdav': []}
        result = firewall_manager._build_acl_block(blocked)
        self.assertIn('acl calendar.cell.', result)
        self.assertIn('block net 10.0.0.5/32', result)
        self.assertIn('allow net 0.0.0.0/0', result)

    def test_unknown_service_skipped(self):
        blocked = {'nonexistent': ['10.0.0.2']}
        result = firewall_manager._build_acl_block(blocked)
        self.assertEqual(result, '')

    def test_multiple_peers_blocked_from_same_service(self):
        blocked = {'mail': ['10.0.0.2', '10.0.0.3'],
                   'calendar': [], 'files': [], 'webdav': []}
        result = firewall_manager._build_acl_block(blocked)
        self.assertEqual(result.count('block net'), 2)
        self.assertIn('10.0.0.2/32', result)
        self.assertIn('10.0.0.3/32', result)


# ---------------------------------------------------------------------------
# generate_corefile
# ---------------------------------------------------------------------------

class TestGenerateCorefile(unittest.TestCase):

    def setUp(self):
        self.tmp = tempfile.mkdtemp()
        self.path = os.path.join(self.tmp, 'Corefile')

    def tearDown(self):
        shutil.rmtree(self.tmp)

    def _content(self):
        """Return the generated Corefile text, closing the file handle promptly."""
        return Path(self.path).read_text()

    def test_creates_corefile(self):
        firewall_manager.generate_corefile([], self.path)
        self.assertTrue(os.path.exists(self.path))

    def test_contains_forward_and_cache(self):
        firewall_manager.generate_corefile([], self.path)
        content = self._content()
        self.assertIn('forward . 8.8.8.8', content)
        self.assertIn('cache', content)
        self.assertIn('cell {', content)

    def test_no_blocked_services_no_acl_block(self):
        peers = [_make_peer('10.0.0.2', internet=True,
                            services=list(firewall_manager.SERVICE_IPS.keys()))]
        firewall_manager.generate_corefile(peers, self.path)
        self.assertNotIn('block net', self._content())

    def test_blocked_service_generates_acl(self):
        peers = [_make_peer('10.0.0.3', internet=False, services=['calendar'])]
        firewall_manager.generate_corefile(peers, self.path)
        # files/mail/webdav are blocked for this peer
        self.assertIn('block net 10.0.0.3/32', self._content())

    def test_peer_with_all_services_allowed_no_acl(self):
        peers = [_make_peer('10.0.0.2',
                            services=list(firewall_manager.SERVICE_IPS.keys()))]
        firewall_manager.generate_corefile(peers, self.path)
        self.assertNotIn('block net', self._content())

    def test_returns_false_on_write_error(self):
        # Use a regular file as the "parent directory". Writing beneath it
        # fails whether or not generate_corefile tries to create missing
        # parent directories, even when the suite runs as root, and nothing
        # is ever written outside this test's private temp dir.
        blocker = os.path.join(self.tmp, 'not-a-dir')
        Path(blocker).write_text('')
        result = firewall_manager.generate_corefile(
            [], os.path.join(blocker, 'Corefile'))
        self.assertFalse(result)


# ---------------------------------------------------------------------------
# generate_corefile with cell_links
# ---------------------------------------------------------------------------

class TestGenerateCorefileWithCellLinks(unittest.TestCase):

    def setUp(self):
        self.tmp = tempfile.mkdtemp()
        self.path = os.path.join(self.tmp, 'Corefile')

    def tearDown(self):
        shutil.rmtree(self.tmp)

    def _content(self):
        """Return the generated Corefile text, closing the file handle promptly."""
        return Path(self.path).read_text()

    def _forward_lines(self):
        """Return every line of the generated Corefile that mentions 'forward'."""
        return [line for line in self._content().splitlines() if 'forward' in line]

    def test_cell_links_none_produces_no_forwarding_stanzas(self):
        """Default (None) produces no extra forwarding blocks beyond the primary zone."""
        firewall_manager.generate_corefile([], self.path, cell_links=None)
        # The only 'forward' line should be the default internet forwarder
        forward_lines = self._forward_lines()
        self.assertEqual(len(forward_lines), 1)
        self.assertIn('8.8.8.8', forward_lines[0])

    def test_cell_links_empty_list_produces_no_extra_stanzas(self):
        """An empty cell_links list produces no extra forwarding blocks."""
        firewall_manager.generate_corefile([], self.path, cell_links=[])
        forward_lines = self._forward_lines()
        self.assertEqual(len(forward_lines), 1)
        self.assertIn('8.8.8.8', forward_lines[0])

    def test_single_cell_link_produces_forwarding_block(self):
        """One cell link produces one forwarding stanza with correct domain and dns_ip."""
        cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.1.0.1'}]
        firewall_manager.generate_corefile([], self.path, cell_links=cell_links)
        content = self._content()
        self.assertIn('remote.cell {', content)
        self.assertIn('forward . 10.1.0.1', content)
        self.assertIn('cache', content)

    def test_multiple_cell_links_produce_multiple_forwarding_blocks(self):
        """Multiple cell links produce one stanza each."""
        cell_links = [
            {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'},
            {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'},
        ]
        firewall_manager.generate_corefile([], self.path, cell_links=cell_links)
        content = self._content()
        self.assertIn('alpha.cell {', content)
        self.assertIn('forward . 10.1.0.1', content)
        self.assertIn('beta.cell {', content)
        self.assertIn('forward . 10.2.0.1', content)

    def test_cell_links_do_not_overwrite_peer_acls(self):
        """Cell link stanzas are appended; peer ACLs in the primary zone survive."""
        peers = [_make_peer('10.0.0.3', services=['calendar'])]
        cell_links = [{'domain': 'other.cell', 'dns_ip': '10.99.0.1'}]
        firewall_manager.generate_corefile(peers, self.path, cell_links=cell_links)
        content = self._content()
        self.assertIn('block net 10.0.0.3/32', content)
        self.assertIn('other.cell {', content)
        self.assertIn('forward . 10.99.0.1', content)

    def test_link_with_missing_domain_is_skipped(self):
        """A cell_link entry with no domain key is silently skipped."""
        cell_links = [{'dns_ip': '10.1.0.1'}]  # no 'domain'
        firewall_manager.generate_corefile([], self.path, cell_links=cell_links)
        # Only the default internet forwarder
        self.assertEqual(len(self._forward_lines()), 1)

    def test_link_with_missing_dns_ip_is_skipped(self):
        """A cell_link entry with no dns_ip key is silently skipped."""
        cell_links = [{'domain': 'nope.cell'}]  # no 'dns_ip'
        firewall_manager.generate_corefile([], self.path, cell_links=cell_links)
        self.assertNotIn('nope.cell', self._content())


# ---------------------------------------------------------------------------
# apply_peer_rules — iptables call verification
# ---------------------------------------------------------------------------

class TestApplyPeerRules(unittest.TestCase):
    """Verify correct iptables calls for full-internet vs split-tunnel peers."""

    def _run_apply(self, peer_ip, settings):
        """Run apply_peer_rules with _wg_exec faked; return every args list it received.

        The fake reports success (returncode 0) with empty stdout for every call.
        """
        calls_made = []

        def fake_wg_exec(args):
            calls_made.append(args)
            m = MagicMock()
            m.returncode = 0
            m.stdout = ''
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec):
            firewall_manager.apply_peer_rules(peer_ip, settings)
        return calls_made

    def test_full_internet_peer_gets_accept_rule(self):
        calls = self._run_apply('10.0.0.2', {
            'internet_access': True,
            'service_access': list(firewall_manager.SERVICE_IPS.keys()),
            'peer_access': True,
        })
        iptables_calls = [c for c in calls if 'iptables' in c]
        targets = [c[c.index('-j') + 1] for c in iptables_calls if '-j' in c]
        # Full-internet peer: only ACCEPT rules (no DROP except iptables-restore clears)
        self.assertNotIn('DROP', targets)
        self.assertIn('ACCEPT', targets)

    def test_no_internet_peer_gets_drop_rule(self):
        calls = self._run_apply('10.0.0.3', {
            'internet_access': False,
            'service_access': list(firewall_manager.SERVICE_IPS.keys()),
            'peer_access': True,
        })
        iptables_calls = [c for c in calls if 'iptables' in c]
        targets = [c[c.index('-j') + 1] for c in iptables_calls if '-j' in c]
        self.assertIn('DROP', targets)
        self.assertIn('ACCEPT', targets)

    def test_service_access_restriction_generates_drop(self):
        calls = self._run_apply('10.0.0.4', {
            'internet_access': False,
            'service_access': ['calendar'],
            'peer_access': True,
        })
        iptables_calls = [c for c in calls if 'iptables' in c]
        # files/mail/webdav should be DROPped, calendar ACCEPTed
        targets_with_ips = [
            (c[c.index('-d') + 1], c[c.index('-j') + 1])
            for c in iptables_calls if '-d' in c and '-j' in c
        ]
        svc_rules = {ip: t for ip, t in targets_with_ips
                     if ip in firewall_manager.SERVICE_IPS.values()}
        calendar_ip = firewall_manager.SERVICE_IPS['calendar']
        files_ip = firewall_manager.SERVICE_IPS['files']
        self.assertEqual(svc_rules.get(calendar_ip), 'ACCEPT')
        self.assertEqual(svc_rules.get(files_ip), 'DROP')

    def test_all_rules_tagged_with_peer_comment(self):
        calls = self._run_apply('10.0.0.2', {
            'internet_access': True,
            'service_access': list(firewall_manager.SERVICE_IPS.keys()),
            'peer_access': True,
        })
        iptables_calls = [c for c in calls if 'iptables' in c]
        comment = firewall_manager._peer_comment('10.0.0.2')
        for c in iptables_calls:
            if '-I' in c:  # only insertion rules need the comment
                self.assertIn(comment, c, f"Rule missing comment tag: {c}")

    def test_peer_with_no_peer_access_gets_drop_for_vpn_subnet(self):
        calls = self._run_apply('10.0.0.5', {
            'internet_access': True,
            'service_access': list(firewall_manager.SERVICE_IPS.keys()),
            'peer_access': False,
        })
        iptables_calls = [c for c in calls if 'iptables' in c]
        vpn_rules = [c for c in iptables_calls if '-d' in c and '10.0.0.0/24' in c]
        self.assertTrue(vpn_rules, "Expected a rule for 10.0.0.0/24")
        for c in vpn_rules:
            self.assertIn('DROP', c)
# ---------------------------------------------------------------------------
# apply_all_peer_rules
# ---------------------------------------------------------------------------

class TestApplyAllPeerRules(unittest.TestCase):

    def test_calls_apply_per_peer(self):
        peers = [_make_peer('10.0.0.2'), _make_peer('10.0.0.3', internet=False)]
        with patch.object(firewall_manager, 'ensure_caddy_virtual_ips', return_value=True), \
             patch.object(firewall_manager, 'apply_peer_rules', return_value=True) as mock_apply:
            firewall_manager.apply_all_peer_rules(peers)
        self.assertEqual(mock_apply.call_count, 2)
        called_ips = {c.args[0] for c in mock_apply.call_args_list}
        self.assertEqual(called_ips, {'10.0.0.2', '10.0.0.3'})

    def test_peer_without_ip_is_skipped(self):
        peers = [{'internet_access': True}, _make_peer('10.0.0.2')]
        with patch.object(firewall_manager, 'ensure_caddy_virtual_ips', return_value=True), \
             patch.object(firewall_manager, 'apply_peer_rules', return_value=True) as mock_apply:
            firewall_manager.apply_all_peer_rules(peers)
        self.assertEqual(mock_apply.call_count, 1)


# ---------------------------------------------------------------------------
# clear_peer_rules
# ---------------------------------------------------------------------------

class TestClearPeerRules(unittest.TestCase):

    def test_removes_only_matching_comment_lines(self):
        save_output = (
            '*filter\n'
            ':INPUT ACCEPT [0:0]\n'
            ':FORWARD ACCEPT [0:0]\n'
            '-A FORWARD -s 10.0.0.2 -m comment --comment "pic-peer-10-0-0-2/32" -j ACCEPT\n'
            '-A FORWARD -s 10.0.0.3 -m comment --comment "pic-peer-10-0-0-3/32" -j DROP\n'
            'COMMIT\n'
        )
        restored = []

        def fake_wg_exec(args):
            m = MagicMock()
            m.returncode = 0
            if args == ['iptables-save']:
                m.stdout = save_output
            return m

        def fake_restore(*_args, **kwargs):
            # subprocess.run takes ``input`` as a keyword-only argument; read it
            # from kwargs rather than shadowing the ``input`` builtin.
            restored.append(kwargs['input'])
            m = MagicMock()
            m.returncode = 0
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch('subprocess.run', side_effect=fake_restore):
            firewall_manager.clear_peer_rules('10.0.0.2')

        self.assertEqual(len(restored), 1)
        restored_content = restored[0]
        self.assertNotIn('pic-peer-10-0-0-2/32', restored_content)
        self.assertIn('pic-peer-10-0-0-3/32', restored_content)

    def test_no_op_when_no_matching_rules(self):
        save_output = '*filter\n:FORWARD ACCEPT [0:0]\nCOMMIT\n'

        def fake_wg_exec(args):
            m = MagicMock()
            m.returncode = 0
            m.stdout = save_output
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch('subprocess.run') as mock_restore:
            firewall_manager.clear_peer_rules('10.0.0.99')
        mock_restore.assert_not_called()


# ---------------------------------------------------------------------------
# update_service_ips
# ---------------------------------------------------------------------------

class TestUpdateServiceIps(unittest.TestCase):

    def tearDown(self):
        # Restore default SERVICE_IPS after each test
        firewall_manager.update_service_ips('172.20.0.0/16')

    def test_default_ips_are_172_20(self):
        self.assertEqual(firewall_manager.SERVICE_IPS['calendar'], '172.20.0.21')
        self.assertEqual(firewall_manager.SERVICE_IPS['webdav'], '172.20.0.24')

    def test_update_changes_all_virtual_ips(self):
        firewall_manager.update_service_ips('10.0.0.0/24')
        self.assertEqual(firewall_manager.SERVICE_IPS['calendar'], '10.0.0.21')
        self.assertEqual(firewall_manager.SERVICE_IPS['files'], '10.0.0.22')
        self.assertEqual(firewall_manager.SERVICE_IPS['mail'], '10.0.0.23')
        self.assertEqual(firewall_manager.SERVICE_IPS['webdav'], '10.0.0.24')

    def test_update_replaces_not_extends(self):
        firewall_manager.update_service_ips('10.0.0.0/24')
        # Should only have the four virtual-IP keys
        self.assertEqual(set(firewall_manager.SERVICE_IPS.keys()),
                         {'calendar', 'files', 'mail', 'webdav'})

    def test_apply_peer_rules_uses_updated_ips(self):
        firewall_manager.update_service_ips('10.0.0.0/24')
        called_with = []

        def fake_wg_exec(args):
            called_with.append(args)
            m = MagicMock()
            m.returncode = 1  # simulate rule-doesn't-exist → _ensure_rule inserts
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch.object(firewall_manager, 'clear_peer_rules'):
            firewall_manager.apply_peer_rules('10.0.0.5', {
                'internet_access': True,
                'service_access': ['calendar'],
                'peer_access': True,
            })

        iptables_calls = [c for c in called_with if c and c[0] == 'iptables']
        dest_ips = [c[c.index('-d') + 1] for c in iptables_calls if '-d' in c]
        # calendar vIP should now be 10.0.0.21
        self.assertIn('10.0.0.21', dest_ips)
        # old IP must not appear
        self.assertNotIn('172.20.0.21', dest_ips)


# ---------------------------------------------------------------------------
# TestCellRules
# ---------------------------------------------------------------------------

class TestCellRules(unittest.TestCase):
    """Tests for apply_cell_rules, clear_cell_rules, _cell_tag, and apply_all_cell_rules."""

    # ── helpers ───────────────────────────────────────────────────────────────

    _FAKE_API_IP = '172.20.0.10'

    def _capture_apply(self, cell_name, vpn_subnet, inbound_services):
        """Run apply_cell_rules with _wg_exec and _get_cell_api_ip mocked.

        Returns only the captured args lists that contain an ``'iptables'`` element.
        """
        calls_made = []

        def fake_wg_exec(args):
            calls_made.append(args)
            m = MagicMock()
            m.returncode = 0
            m.stdout = ''
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch.object(firewall_manager, '_get_cell_api_ip', return_value=self._FAKE_API_IP):
            firewall_manager.apply_cell_rules(cell_name, vpn_subnet, inbound_services)
        return [c for c in calls_made if 'iptables' in c]

    def _targets_for_dest(self, iptables_calls, dest_ip):
        """Return list of -j targets where -d matches dest_ip."""
        targets = []
        for c in iptables_calls:
            if '-d' in c and dest_ip in c and '-j' in c:
                targets.append(c[c.index('-j') + 1])
        return targets

    # ── _cell_tag ─────────────────────────────────────────────────────────────

    def test_cell_tag_sanitises_spaces_and_punctuation(self):
        """_cell_tag replaces non-alphanumeric chars with dashes."""
        tag = firewall_manager._cell_tag('my cell!')
        self.assertTrue(tag.startswith('pic-cell-'))
        self.assertNotIn(' ', tag)
        self.assertNotIn('!', tag)

    def test_cell_tag_lowercase(self):
        """_cell_tag lowercases the cell name."""
        tag = firewall_manager._cell_tag('Office')
        self.assertIn('office', tag)

    def test_cell_tag_has_pic_cell_prefix(self):
        """_cell_tag always starts with 'pic-cell-'."""
        self.assertTrue(firewall_manager._cell_tag('remote').startswith('pic-cell-'))

    def test_cell_tag_distinct_from_peer_tag(self):
        """A cell tag must not equal the peer comment for the same string."""
        cell_tag = firewall_manager._cell_tag('10.0.0.2')
        peer_tag = firewall_manager._peer_comment('10.0.0.2')
        self.assertNotEqual(cell_tag, peer_tag)

    # ── apply_cell_rules — catch-all DROP ─────────────────────────────────────

    def test_apply_cell_rules_sends_catch_all_drop(self):
        """apply_cell_rules always inserts a DROP for the entire vpn_subnet."""
        calls = self._capture_apply('office', '10.0.1.0/24', ['calendar'])
        subnet_drops = [
            c for c in calls
            if '-s' in c and '10.0.1.0/24' in c
            and '-j' in c and c[c.index('-j') + 1] == 'DROP'
            and '-d' not in c  # catch-all has no destination
        ]
        self.assertTrue(subnet_drops, "Expected a catch-all DROP rule for the subnet")

    def test_apply_cell_rules_sends_accept_for_allowed_service(self):
        """apply_cell_rules inserts ACCEPT for the calendar VIP when calendar is in inbound."""
        calls = self._capture_apply('office', '10.0.1.0/24', ['calendar'])
        calendar_ip = firewall_manager.SERVICE_IPS['calendar']
        calendar_targets = self._targets_for_dest(calls, calendar_ip)
        self.assertIn('ACCEPT', calendar_targets)

    def test_apply_cell_rules_sends_drop_for_disallowed_service(self):
        """apply_cell_rules inserts DROP for a service not in inbound_services."""
        calls = self._capture_apply('office', '10.0.1.0/24', ['calendar'])
        files_ip = firewall_manager.SERVICE_IPS['files']
        files_targets = self._targets_for_dest(calls, files_ip)
        self.assertIn('DROP', files_targets)

    def test_apply_cell_rules_accepts_api_sync_traffic(self):
        """apply_cell_rules inserts ACCEPT for cell-api:3000 so permission-sync pushes pass."""
        calls = self._capture_apply('office', '10.0.1.0/24', [])
        api_ip = self._FAKE_API_IP
        api_accepts = [
            c for c in calls
            if '-s' in c and '10.0.1.0/24' in c
            and '-d' in c and api_ip in c
            and '--dport' in c and '3000' in c
            and '-j' in c and c[c.index('-j') + 1] == 'ACCEPT'
        ]
        self.assertTrue(api_accepts, 'Expected an ACCEPT rule for cell-api:3000')

    def test_apply_cell_rules_api_sync_accept_before_catchall_drop(self):
        """The API-sync ACCEPT must be inserted after service rules so it ends up above DROP."""
        forward_inserts = []

        def fake_wg_exec(args):
            if '-I' in args and 'FORWARD' in args and '-j' in args:
                forward_inserts.append(args)
            m = MagicMock()
            m.returncode = 0
            m.stdout = ''
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch.object(firewall_manager, '_get_cell_api_ip', return_value=self._FAKE_API_IP):
            firewall_manager.apply_cell_rules('office', '10.0.1.0/24', [])

        # The API-sync ACCEPT must be the LAST -I FORWARD insertion so it sits at
        # position 1. Check both the target and the destination: any ACCEPT being
        # last is not enough, it has to be the cell-api rule.
        self.assertTrue(forward_inserts, 'Expected at least one FORWARD rule inserted')
        last_insert = forward_inserts[-1]
        insertion_order = [c[c.index('-j') + 1] for c in forward_inserts]
        self.assertEqual(last_insert[last_insert.index('-j') + 1], 'ACCEPT',
                         f'Last -I FORWARD insertion must be ACCEPT (got {insertion_order})')
        self.assertIn(self._FAKE_API_IP, last_insert,
                      f'Last -I FORWARD insertion must target cell-api (got {last_insert})')

    # ── apply_cell_rules — empty inbound (all-deny) ───────────────────────────

    def test_apply_cell_rules_empty_inbound_all_drop(self):
        """With inbound_services=[], all per-service rules are DROP."""
        calls = self._capture_apply('office', '10.0.1.0/24', [])
        for service, svc_ip in firewall_manager.SERVICE_IPS.items():
            svc_targets = self._targets_for_dest(calls, svc_ip)
            self.assertTrue(svc_targets,
                            f"Expected at least one rule for {service} ({svc_ip})")
            self.assertNotIn('ACCEPT', svc_targets,
                             f"{service} should be DROP when not in inbound_services")

    # ── apply_cell_rules — all inbound (all-accept) ───────────────────────────

    def test_apply_cell_rules_all_inbound_all_accept(self):
        """With all four services in inbound, all per-service rules are ACCEPT."""
        all_services = list(firewall_manager.SERVICE_IPS.keys())
        calls = self._capture_apply('office', '10.0.1.0/24', all_services)
        for service, svc_ip in firewall_manager.SERVICE_IPS.items():
            svc_targets = self._targets_for_dest(calls, svc_ip)
            self.assertIn('ACCEPT', svc_targets,
                          f"{service} should be ACCEPT when in inbound_services")

    # ── apply_cell_rules — all rules tagged ───────────────────────────────────

    def test_apply_cell_rules_all_rules_tagged_with_cell_tag(self):
        """Every insertion rule must carry the cell's comment tag."""
        calls = self._capture_apply('office', '10.0.1.0/24', ['calendar'])
        tag = firewall_manager._cell_tag('office')
        for c in calls:
            if '-I' in c:
                self.assertIn(tag, c, f"Rule missing cell tag: {c}")

    # ── clear_cell_rules — noop when no matching rules ────────────────────────

    def test_clear_cell_rules_noop_when_no_rules(self):
        """When iptables-save returns no pic-cell-office lines, iptables-restore is NOT called."""
        save_output = '*filter\n:FORWARD ACCEPT [0:0]\nCOMMIT\n'

        def fake_wg_exec(args):
            m = MagicMock()
            m.returncode = 0
            m.stdout = save_output
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch('subprocess.run') as mock_restore:
            firewall_manager.clear_cell_rules('office')
        mock_restore.assert_not_called()

    def test_clear_cell_rules_removes_tagged_lines(self):
        """clear_cell_rules removes lines carrying the cell tag and keeps others."""
        tag = firewall_manager._cell_tag('office')
        save_output = (
            '*filter\n'
            ':FORWARD ACCEPT [0:0]\n'
            f'-A FORWARD -s 10.0.1.0/24 -m comment --comment "{tag}" -j DROP\n'
            '-A FORWARD -s 10.0.0.2 -m comment --comment "pic-peer-10-0-0-2/32" -j ACCEPT\n'
            'COMMIT\n'
        )
        restored = []

        def fake_wg_exec(args):
            m = MagicMock()
            m.returncode = 0
            if args == ['iptables-save']:
                m.stdout = save_output
            return m

        def fake_restore(*_args, **kwargs):
            # subprocess.run takes ``input`` as a keyword-only argument; read it
            # from kwargs rather than shadowing the ``input`` builtin.
            restored.append(kwargs['input'])
            m = MagicMock()
            m.returncode = 0
            return m

        with patch.object(firewall_manager, '_wg_exec', side_effect=fake_wg_exec), \
             patch('subprocess.run', side_effect=fake_restore):
            firewall_manager.clear_cell_rules('office')

        self.assertEqual(len(restored), 1)
        content = restored[0]
        self.assertNotIn(tag, content)
        # peer rule for a different entity must survive
        self.assertIn('pic-peer-10-0-0-2/32', content)

    # ── apply_all_cell_rules ──────────────────────────────────────────────────

    def test_apply_all_cell_rules_calls_apply_for_each(self):
        """apply_all_cell_rules calls apply_cell_rules once per link with correct args."""
        cell_links = [
            {
                'cell_name': 'office',
                'vpn_subnet': '10.1.0.0/24',
                'permissions': {'inbound': {'calendar': True, 'files': False,
                                            'mail': False, 'webdav': False},
                                'outbound': {}},
            },
            {
                'cell_name': 'cabin',
                'vpn_subnet': '10.2.0.0/24',
                'permissions': {'inbound': {'calendar': False, 'files': True,
                                            'mail': False, 'webdav': False},
                                'outbound': {}},
            },
        ]
        with patch.object(firewall_manager, 'apply_cell_rules', return_value=True) as mock_apply:
            firewall_manager.apply_all_cell_rules(cell_links)
        self.assertEqual(mock_apply.call_count, 2)
        call_kwargs = {c.args[0]: c.args for c in mock_apply.call_args_list}
        self.assertIn('office', call_kwargs)
        self.assertIn('cabin', call_kwargs)
        office_args = call_kwargs['office']
        self.assertEqual(office_args[1], '10.1.0.0/24')
        self.assertIn('calendar', office_args[2])
        self.assertNotIn('files', office_args[2])

    def test_apply_all_cell_rules_skips_links_with_missing_fields(self):
        """Links without cell_name or vpn_subnet are silently skipped."""
        cell_links = [
            {'vpn_subnet': '10.1.0.0/24'},  # no cell_name
            {'cell_name': 'broken'},        # no vpn_subnet
            {'cell_name': 'office', 'vpn_subnet': '10.3.0.0/24',
             'permissions': {'inbound': {}, 'outbound': {}}},
        ]
        with patch.object(firewall_manager, 'apply_cell_rules', return_value=True) as mock_apply:
            firewall_manager.apply_all_cell_rules(cell_links)
        # Only the complete entry should be processed
        self.assertEqual(mock_apply.call_count, 1)
        self.assertEqual(mock_apply.call_args.args[0], 'office')


class TestEnsureCellApiDnat(unittest.TestCase):
    """Tests for ensure_cell_api_dnat — DNAT wg0:3000 → cell-api:3000."""

    def _wg_exec_no_existing_rules(self, args):
        r = MagicMock()
        r.returncode = 1 if '-C' in args else 0  # -C = check: fail = not present
        r.stdout = ''
        r.stderr = ''
        return r

    def _wg_exec_all_rules_exist(self, args):
        r = MagicMock()
        r.returncode = 0  # -C succeeds = rule already present
        r.stdout = ''
        return r

    def _inspect_ok(self, api_ip='172.20.0.10'):
        r = MagicMock()
        r.returncode = 0
        r.stdout = api_ip
        return r

    def test_dnat_rules_added_when_not_present(self):
        with patch.object(firewall_manager, '_run', return_value=self._inspect_ok()), \
             patch.object(firewall_manager, '_wg_exec',
                          side_effect=self._wg_exec_no_existing_rules) as wg_mock:
            result = firewall_manager.ensure_cell_api_dnat()
        self.assertTrue(result)
        calls_args = [c.args[0] for c in wg_mock.call_args_list]
        dnat_adds = [a for a in calls_args if 'DNAT' in a and '-A' in a]
        self.assertTrue(dnat_adds, 'DNAT -A rule must be added')

    def test_dnat_skipped_if_already_present(self):
        with patch.object(firewall_manager, '_run', return_value=self._inspect_ok()), \
             patch.object(firewall_manager, '_wg_exec',
                          side_effect=self._wg_exec_all_rules_exist) as wg_mock:
            result = firewall_manager.ensure_cell_api_dnat()
        self.assertTrue(result)
        calls_args = [c.args[0] for c in wg_mock.call_args_list]
        add_calls = [a for a in calls_args if '-A' in a or '-I' in a]
        self.assertEqual(len(add_calls), 0,
                         'No rules should be added when they already exist')

    def test_returns_false_when_cell_api_not_found(self):
        r = MagicMock()
        r.returncode = 0
        r.stdout = ''
        with patch.object(firewall_manager, '_run', return_value=r):
            result = firewall_manager.ensure_cell_api_dnat()
        self.assertFalse(result)

    def test_returns_false_on_exception(self):
        with patch.object(firewall_manager, '_run', side_effect=RuntimeError('docker gone')):
            result = firewall_manager.ensure_cell_api_dnat()
        self.assertFalse(result)


if __name__ == '__main__':
    unittest.main()