diff --git a/api/app.py b/api/app.py index 197f1d7..a13b051 100644 --- a/api/app.py +++ b/api/app.py @@ -268,6 +268,7 @@ def _apply_startup_enforcement(): try: peers = peer_registry.list_peers() cell_links = cell_link_manager.list_connections() + firewall_manager.reconcile_stale_peer_rules(peers) firewall_manager.apply_all_peer_rules(peers) firewall_manager.apply_all_cell_rules(cell_links) firewall_manager.ensure_cell_api_dnat() diff --git a/api/firewall_manager.py b/api/firewall_manager.py index 14efac9..94a7e4c 100644 --- a/api/firewall_manager.py +++ b/api/firewall_manager.py @@ -221,6 +221,42 @@ def apply_all_peer_rules(peers: List[Dict[str, Any]]) -> None: }) +def reconcile_stale_peer_rules(peers: List[Dict[str, Any]]) -> int: + """Remove iptables rules for peer IPs that are no longer in the registry. + + Returns the number of stale IPs cleaned up. + """ + known_ips = set() + for peer in peers: + raw = peer.get('ip', '') + ip = raw.split('/')[0] if raw else '' + if ip: + known_ips.add(ip) + + # Parse pic-peer-* comments from iptables-save to find IPs with live rules + save = _wg_exec(['iptables-save']) + if save.returncode != 0: + return 0 + + # Comment format: pic-peer-A-B-C-D/32 (dots replaced with dashes) + comment_re = re.compile(r'pic-peer-([\d]+-[\d]+-[\d]+-[\d]+)/32') + stale_ips: set = set() + for line in save.stdout.splitlines(): + m = comment_re.search(line) + if m: + ip = m.group(1).replace('-', '.') + if ip not in known_ips: + stale_ips.add(ip) + + for ip in stale_ips: + logger.warning(f"Removing stale iptables rules for deleted peer {ip}") + clear_peer_rules(ip) + + if stale_ips: + logger.info(f"Reconciled {len(stale_ips)} stale peer(s): {sorted(stale_ips)}") + return len(stale_ips) + + # --------------------------------------------------------------------------- # Cell-to-cell firewall rules # --------------------------------------------------------------------------- diff --git a/api/wireguard_manager.py b/api/wireguard_manager.py index cc374ef..bb9449a 100644 --- a/api/wireguard_manager.py +++ b/api/wireguard_manager.py @@ -134,11 +134,11 @@ class WireGuardManager(BaseServiceManager): f'PrivateKey = {keys["private_key"]}\n' f'Address = {address}\n' f'ListenPort = {cfg_port}\n' - f'PostUp = iptables -A FORWARD -i %i -j ACCEPT; ' + f'PostUp = iptables -A FORWARD -i %i -j DROP; ' f'iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE; ' f'{hairpin}' f'sysctl -q net.ipv4.conf.all.rp_filter=0 || true\n' - f'PostDown = iptables -D FORWARD -i %i -j ACCEPT; ' + f'PostDown = iptables -D FORWARD -i %i -j DROP; ' f'iptables -t nat -D POSTROUTING -o eth0 -j MASQUERADE; ' f'{hairpin_down}' f'sysctl -q net.ipv4.conf.all.rp_filter=1 || true\n' diff --git a/tests/test_firewall_manager.py b/tests/test_firewall_manager.py index b386449..bd45dc1 100644 --- a/tests/test_firewall_manager.py +++ b/tests/test_firewall_manager.py @@ -715,5 +715,84 @@ class TestEnsureCellApiDnat(unittest.TestCase): self.assertFalse(result) +# --------------------------------------------------------------------------- +# reconcile_stale_peer_rules +# --------------------------------------------------------------------------- + +class TestReconcileStale(unittest.TestCase): + + def _save_result(self, stdout_text): + r = MagicMock() + r.returncode = 0 + r.stdout = stdout_text + return r + + def test_returns_zero_when_no_rules(self): + with patch.object(firewall_manager, '_wg_exec', return_value=self._save_result('*filter\nCOMMIT\n')): + n = firewall_manager.reconcile_stale_peer_rules([]) + self.assertEqual(n, 0) + + def test_returns_zero_when_all_peers_known(self): + save_out = ( + '*filter\n' + '-A FORWARD -s 10.0.0.2 -m comment --comment "pic-peer-10-0-0-2/32" -j ACCEPT\n' + 'COMMIT\n' + ) + peers = [{'ip': '10.0.0.2'}] + with patch.object(firewall_manager, '_wg_exec', return_value=self._save_result(save_out)): + n = firewall_manager.reconcile_stale_peer_rules(peers) + self.assertEqual(n, 0) + + def test_clears_stale_peer(self): + save_out = ( + '*filter\n' + '-A FORWARD -s 10.0.0.9 -m comment --comment "pic-peer-10-0-0-9/32" -j ACCEPT\n' + 'COMMIT\n' + ) + cleared = [] + with patch.object(firewall_manager, '_wg_exec', return_value=self._save_result(save_out)): + with patch.object(firewall_manager, 'clear_peer_rules', side_effect=cleared.append) as mock_clear: + n = firewall_manager.reconcile_stale_peer_rules([]) + self.assertEqual(n, 1) + mock_clear.assert_called_once_with('10.0.0.9') + + def test_handles_cidr_peer_ip(self): + """Peer IPs stored as 'x.x.x.x/32' should still match.""" + save_out = ( + '*filter\n' + '-A FORWARD -s 10.0.0.5 -m comment --comment "pic-peer-10-0-0-5/32" -j ACCEPT\n' + 'COMMIT\n' + ) + peers = [{'ip': '10.0.0.5/32'}] + with patch.object(firewall_manager, '_wg_exec', return_value=self._save_result(save_out)): + with patch.object(firewall_manager, 'clear_peer_rules') as mock_clear: + n = firewall_manager.reconcile_stale_peer_rules(peers) + self.assertEqual(n, 0) + mock_clear.assert_not_called() + + def test_returns_zero_on_iptables_save_failure(self): + fail_r = MagicMock() + fail_r.returncode = 1 + fail_r.stdout = '' + with patch.object(firewall_manager, '_wg_exec', return_value=fail_r): + n = firewall_manager.reconcile_stale_peer_rules([]) + self.assertEqual(n, 0) + + def test_multiple_stale_ips_all_cleared(self): + save_out = ( + '*filter\n' + '-A FORWARD -s 10.0.0.7 -m comment --comment "pic-peer-10-0-0-7/32" -j DROP\n' + '-A FORWARD -s 10.0.0.8 -m comment --comment "pic-peer-10-0-0-8/32" -j ACCEPT\n' + 'COMMIT\n' + ) + cleared = [] + with patch.object(firewall_manager, '_wg_exec', return_value=self._save_result(save_out)): + with patch.object(firewall_manager, 'clear_peer_rules', side_effect=cleared.append): + n = firewall_manager.reconcile_stale_peer_rules([]) + self.assertEqual(n, 2) + self.assertIn('10.0.0.7', cleared) + self.assertIn('10.0.0.8', cleared) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_wireguard_manager.py b/tests/test_wireguard_manager.py index 782d7b3..a7e87d5 100644 --- a/tests/test_wireguard_manager.py +++ b/tests/test_wireguard_manager.py @@ -480,9 +480,10 @@ class TestWireGuardSysctlAndPortCheck(unittest.TestCase): cfg = self.wg.generate_config() self.assertIn('MASQUERADE', cfg) - def test_generate_config_has_forward_rule(self): + def test_generate_config_has_forward_drop_rule(self): cfg = self.wg.generate_config() - self.assertIn('FORWARD -i %i -j ACCEPT', cfg) + self.assertIn('FORWARD -i %i -j DROP', cfg) + self.assertNotIn('FORWARD -i %i -j ACCEPT', cfg) # ── check_port_open ─────────────────────────────────────────────────────── diff --git a/tests/test_wireguard_vpn_routing.py b/tests/test_wireguard_vpn_routing.py index 777deb2..8dda35d 100644 --- a/tests/test_wireguard_vpn_routing.py +++ b/tests/test_wireguard_vpn_routing.py @@ -46,8 +46,13 @@ def _make_wg(tmp: str) -> WireGuardManager: class TestInternetForwardingRules(unittest.TestCase): """ Verify that generate_config() emits the exact iptables rules required for - 'internet through VPN': MASQUERADE on eth0 (outbound NAT) and FORWARD ACCEPT - on the wg0 interface. Missing either rule means VPN clients get no internet. + 'internet through VPN': MASQUERADE on eth0 (outbound NAT) and a catch-all + FORWARD DROP on the wg0 interface. + + The catch-all is DROP (not ACCEPT) so that only per-peer rules inserted at + chain position 1 via apply_peer_rules() can forward traffic. An ACCEPT + catch-all would allow any WireGuard-connected client full internet access + even if they have no entry in peers.json. """ def setUp(self): @@ -61,10 +66,11 @@ class TestInternetForwardingRules(unittest.TestCase): cfg = self.wg.generate_config() self.assertIn('POSTROUTING -o eth0 -j MASQUERADE', cfg) - def test_postup_has_forward_accept_on_wg_interface(self): - """FORWARD ACCEPT allows packets from the WireGuard interface through the kernel.""" + def test_postup_has_forward_drop_on_wg_interface(self): + """Catch-all DROP blocks unconfigured WG clients; per-peer rules inserted above it allow known peers.""" cfg = self.wg.generate_config() - self.assertIn('FORWARD -i %i -j ACCEPT', cfg) + self.assertIn('FORWARD -i %i -j DROP', cfg) + self.assertNotIn('FORWARD -i %i -j ACCEPT', cfg) def test_postdown_removes_masquerade_rule(self): """PostDown must mirror PostUp so rules are cleaned up when the tunnel goes down.""" @@ -73,7 +79,7 @@ class TestInternetForwardingRules(unittest.TestCase): def test_postdown_removes_forward_rule(self): cfg = self.wg.generate_config() - self.assertIn('FORWARD -i %i -j ACCEPT', cfg.split('PostDown')[1]) + self.assertIn('FORWARD -i %i -j DROP', cfg.split('PostDown')[1]) def test_postup_and_postdown_are_present(self): """Both PostUp and PostDown must exist — PostUp without PostDown leaks rules."""