diff --git a/api/app.py b/api/app.py index 99b8fa8..26aa91e 100644 --- a/api/app.py +++ b/api/app.py @@ -317,6 +317,7 @@ def _apply_startup_enforcement(): _cell_subnets = [l['vpn_subnet'] for l in cell_links if l.get('vpn_subnet')] firewall_manager.apply_all_peer_rules(peers, wg_subnet=_wg_subnet, cell_subnets=_cell_subnets) firewall_manager.apply_all_cell_rules(cell_links) + firewall_manager.ensure_forward_stateful() firewall_manager.ensure_cell_api_dnat() # Embed DNAT rules in PostUp so they survive WireGuard interface restarts, # then also apply them immediately for the current session. diff --git a/api/firewall_manager.py b/api/firewall_manager.py index 77773fb..8e8804b 100644 --- a/api/firewall_manager.py +++ b/api/firewall_manager.py @@ -399,6 +399,10 @@ def apply_cell_rules(cell_name: str, vpn_subnet: str, inbound_services: List[str '-p', 'tcp', '--dport', '3000', '-m', 'comment', '--comment', tag, '-j', 'ACCEPT']) + # Ensure reply traffic (e.g. ICMP, TCP ACKs) for connections initiated + # by local peers to this cell is not dropped by the cell's catch-all DROP. + ensure_forward_stateful() + logger.info( f"Applied cell rules for {cell_name} ({vpn_subnet}): " f"inbound={inbound_services} exit_relay={exit_relay}" @@ -422,6 +426,30 @@ def apply_all_cell_rules(cell_links: List[Dict[str, Any]]) -> None: apply_cell_rules(name, subnet, inbound, exit_relay=exit_relay) +def ensure_forward_stateful() -> bool: + """Insert a stateful ESTABLISHED/RELATED ACCEPT at the top of FORWARD. + + Cell rules DROP all traffic from a connected cell's subnet except specific + service ports. Without conntrack, ICMP replies and TCP ACKs for connections + initiated BY local peers to the connected cell are also dropped, making + cross-cell routing (peer → cell → remote cell) broken. + + This rule is inserted once and does not carry a peer/cell comment tag, so it + is never removed by clear_peer_rules or clear_cell_rules. + """ + try: + check = ['-C', 'FORWARD', '-m', 'state', '--state', 'ESTABLISHED,RELATED', '-j', 'ACCEPT'] + if _wg_exec(['iptables'] + check).returncode == 0: + return True # already present + _wg_exec(['iptables', '-I', 'FORWARD', '1', '-m', 'state', + '--state', 'ESTABLISHED,RELATED', '-j', 'ACCEPT']) + logger.info('ensure_forward_stateful: inserted ESTABLISHED,RELATED ACCEPT into FORWARD') + return True + except Exception as e: + logger.error(f'ensure_forward_stateful: {e}') + return False + + def ensure_cell_api_dnat() -> bool: """DNAT wg0:3000 → cell-api:3000 inside cell-wireguard. diff --git a/tests/test_firewall_manager.py b/tests/test_firewall_manager.py index f61d030..afa04db 100644 --- a/tests/test_firewall_manager.py +++ b/tests/test_firewall_manager.py @@ -929,5 +929,56 @@ class TestReconcileStale(unittest.TestCase): self.assertIn('10.0.0.8', cleared) +# --------------------------------------------------------------------------- +# ensure_forward_stateful +# --------------------------------------------------------------------------- + +class TestEnsureForwardStateful(unittest.TestCase): + """ensure_forward_stateful must insert ESTABLISHED,RELATED ACCEPT only once.""" + + def _make_exec(self, already_present=False): + calls = [] + def fake_wg_exec(args): + calls.append(args) + r = MagicMock() + # -C (check) returns 0 if present, 1 if not + if '-C' in args: + r.returncode = 0 if already_present else 1 + else: + r.returncode = 0 + r.stdout = '' + return r + return calls, fake_wg_exec + + def test_inserts_rule_when_not_present(self): + calls, fake = self._make_exec(already_present=False) + with patch.object(firewall_manager, '_wg_exec', side_effect=fake): + result = firewall_manager.ensure_forward_stateful() + self.assertTrue(result) + insert_calls = [c for c in calls if '-I' in c] + self.assertEqual(len(insert_calls), 1) + flat = ' '.join(insert_calls[0]) + self.assertIn('ESTABLISHED,RELATED', flat) + self.assertIn('ACCEPT', flat) + + def test_skips_insert_when_already_present(self): + calls, fake = self._make_exec(already_present=True) + with patch.object(firewall_manager, '_wg_exec', side_effect=fake): + result = firewall_manager.ensure_forward_stateful() + self.assertTrue(result) + insert_calls = [c for c in calls if '-I' in c] + self.assertEqual(len(insert_calls), 0, "Must not insert duplicate rule") + + def test_apply_cell_rules_calls_ensure_forward_stateful(self): + """apply_cell_rules must call ensure_forward_stateful so replies are never dropped.""" + with patch.object(firewall_manager, '_wg_exec', return_value=MagicMock(returncode=0, stdout='')), \ + patch.object(firewall_manager, '_get_caddy_container_ip', return_value='172.20.0.2'), \ + patch.object(firewall_manager, '_get_dns_container_ip', return_value='172.20.0.3'), \ + patch.object(firewall_manager, '_get_cell_api_ip', return_value='172.20.0.10'), \ + patch.object(firewall_manager, 'ensure_forward_stateful') as mock_stateful: + firewall_manager.apply_cell_rules('testcell', '10.0.0.0/24', []) + mock_stateful.assert_called_once() + + if __name__ == '__main__': unittest.main()