#!/usr/bin/env python3
"""
Tests for cell-to-cell DNS forwarding integration.

Covers:
  - generate_corefile() with cell_links entries
  - apply_all_dns_rules() passing cell_links through to generate_corefile()
  - Correct domain/dns_ip values in the emitted forwarding stanza
  - Validation: malformed domains (invalid characters) and non-IP dns_ip
    values are rejected by add_cell_dns_forward()
"""

import sys
import os
import re
import tempfile
import shutil
import unittest
from unittest.mock import patch, MagicMock, call
from pathlib import Path

api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))

import firewall_manager


# ---------------------------------------------------------------------------
# Corefile text helpers
# ---------------------------------------------------------------------------

# Header of the primary zone for the default domain 'cell'. Anchored to the
# start of a line so it cannot match the 'cell {' suffix of a cell-link header
# such as 'remote.cell {'; any non-brace text may follow (e.g. 'cell:53 {').
_PRIMARY_ZONE_RE = re.compile(r'^[ \t]*cell\b[^\n{]*\{', re.MULTILINE)


def _matching_brace(content, open_idx):
    """Return the index of the '}' closing the '{' at *open_idx* in *content*.

    Nested braces are balanced, so sub-blocks inside the block are skipped.
    Raises AssertionError if the block is never closed.
    """
    depth = 0
    for pos in range(open_idx, len(content)):
        char = content[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return pos
    raise AssertionError(f'unterminated block starting at offset {open_idx}')


def _block_text(content, header):
    """Return the text of the block opened by *header*, through its closing '}'.

    *header* must end with '{' (e.g. 'remote.cell {'). Raises ValueError if
    *header* does not occur in *content*.
    """
    start = content.index(header)
    end = _matching_brace(content, start + len(header) - 1)
    return content[start:end + 1]


class _CorefileTestCase(unittest.TestCase):
    """Shared fixture: renders a Corefile into a private scratch directory."""

    def setUp(self):
        self.tmp = tempfile.mkdtemp()
        # Cleanups run even if a later setUp step fails; tearDown would not.
        self.addCleanup(shutil.rmtree, self.tmp)
        self.path = os.path.join(self.tmp, 'Corefile')

    def _read(self):
        """Return the rendered Corefile text, closing the file handle."""
        with open(self.path, encoding='utf-8') as fh:
            return fh.read()

    def _generate(self, cell_links):
        """Render a Corefile with no peers and *cell_links*; return its text."""
        firewall_manager.generate_corefile([], self.path, cell_links=cell_links)
        return self._read()


# ---------------------------------------------------------------------------
# generate_corefile() with cell_links
# ---------------------------------------------------------------------------

class TestGenerateCorefileOneLink(_CorefileTestCase):
    """generate_corefile() with a single cell link produces the right stanza."""

    def _render(self):
        # A fresh list per call so no test can observe another's mutations.
        return self._generate([{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}])

    def test_forwarding_block_present(self):
        self.assertIn('remote.cell {', self._render())

    def test_correct_dns_ip_in_forward_directive(self):
        # Searched inside the remote.cell block only, so a forward directive
        # emitted elsewhere in the file cannot satisfy the assertion.
        block = _block_text(self._render(), 'remote.cell {')
        self.assertIn('forward . 10.5.0.1', block)

    def test_cache_directive_present_in_forwarding_block(self):
        block = _block_text(self._render(), 'remote.cell {')
        self.assertIn('cache', block)

    def test_log_directive_present_in_forwarding_block(self):
        block = _block_text(self._render(), 'remote.cell {')
        self.assertIn('log', block)

    def test_forwarding_block_appears_after_primary_zone(self):
        """The cell link stanza must appear after the primary zone block, not inside it."""
        content = self._render()
        primary = _PRIMARY_ZONE_RE.search(content)
        self.assertIsNotNone(primary, 'primary cell zone block not found')
        primary_end = _matching_brace(content, primary.end() - 1)
        # The link header must start only after the primary block has closed.
        self.assertGreater(content.index('remote.cell {'), primary_end)


class TestGenerateCorefileMultipleLinks(_CorefileTestCase):
    """generate_corefile() with multiple cell links produces one stanza each."""

    def test_all_domains_present(self):
        domains = ('alpha.cell', 'beta.cell', 'gamma.cell')
        cell_links = [
            {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'},
            {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'},
            {'domain': 'gamma.cell', 'dns_ip': '10.3.0.1'},
        ]
        content = self._generate(cell_links)
        for domain in domains:
            with self.subTest(domain=domain):
                self.assertIn(f'{domain} {{', content)

    def test_all_dns_ips_present(self):
        expected = {'alpha.cell': '10.1.0.1', 'beta.cell': '10.2.0.1'}
        content = self._generate(
            [{'domain': domain, 'dns_ip': ip} for domain, ip in expected.items()]
        )
        # Each IP must appear in its own domain's block, so swapped
        # domain/IP pairings are caught, not just the IPs' presence.
        for domain, dns_ip in expected.items():
            with self.subTest(domain=domain):
                block = _block_text(content, f'{domain} {{')
                self.assertIn(f'forward . {dns_ip}', block)

    def test_stanza_count_matches_link_count(self):
        """Each valid link contributes exactly one forwarding stanza."""
        cell_links = [
            {'domain': 'a.cell', 'dns_ip': '10.1.0.1'},
            {'domain': 'b.cell', 'dns_ip': '10.2.0.1'},
        ]
        content = self._generate(cell_links)
        # Count occurrences of 'forward .' — one for default, one per cell link
        self.assertEqual(content.count('forward .'), 3)  # 1 default + 2 cell links


# ---------------------------------------------------------------------------
# apply_all_dns_rules() passes cell_links through to generate_corefile()
# ---------------------------------------------------------------------------

class TestApplyAllDnsRulesPassesCellLinks(unittest.TestCase):
    """apply_all_dns_rules() must forward the cell_links argument to generate_corefile()."""

    # autospec=True makes assert_called_once_with compare calls through the
    # real generate_corefile signature, so the assertion holds whether
    # apply_all_dns_rules passes arguments positionally or by keyword.

    def test_cell_links_forwarded(self):
        cell_links = [{'domain': 'x.cell', 'dns_ip': '10.9.0.1'}]
        with patch.object(firewall_manager, 'generate_corefile',
                          autospec=True, return_value=True) as mock_gen, \
                patch.object(firewall_manager, 'reload_coredns', return_value=True):
            firewall_manager.apply_all_dns_rules(
                peers=[],
                corefile_path='/tmp/fake_Corefile',
                domain='cell',
                cell_links=cell_links,
            )
        mock_gen.assert_called_once_with(
            [], '/tmp/fake_Corefile', 'cell', cell_links
        )

    def test_cell_links_none_forwarded_as_none(self):
        with patch.object(firewall_manager, 'generate_corefile',
                          autospec=True, return_value=True) as mock_gen, \
                patch.object(firewall_manager, 'reload_coredns', return_value=True):
            firewall_manager.apply_all_dns_rules(
                peers=[],
                corefile_path='/tmp/fake_Corefile',
                domain='cell',
                cell_links=None,
            )
        mock_gen.assert_called_once_with([], '/tmp/fake_Corefile', 'cell', None)

    def test_reload_called_on_success(self):
        with patch.object(firewall_manager, 'generate_corefile', return_value=True), \
                patch.object(firewall_manager, 'reload_coredns', return_value=True) as mock_reload:
            firewall_manager.apply_all_dns_rules([], '/tmp/f', cell_links=None)
        mock_reload.assert_called_once()

    def test_reload_not_called_on_failure(self):
        with patch.object(firewall_manager, 'generate_corefile', return_value=False), \
                patch.object(firewall_manager, 'reload_coredns') as mock_reload:
            firewall_manager.apply_all_dns_rules([], '/tmp/f', cell_links=None)
        mock_reload.assert_not_called()


# ---------------------------------------------------------------------------
# Domain validation in add_cell_dns_forward() (via network_manager)
# ---------------------------------------------------------------------------

class TestAddCellDnsForwardValidation(unittest.TestCase):
    """
    add_cell_dns_forward() must reject malformed domains/IPs without writing
    the Corefile or calling apply_all_dns_rules().
    """

    def _get_network_manager(self, tmp_dir):
        """Construct a minimal NetworkManager rooted at *tmp_dir*.

        Skips the test if network_manager cannot be imported.
        """
        # Imported here so the test file doesn't hard-fail if network_manager
        # has an import-time dependency that's unavailable in CI.
        try:
            from network_manager import NetworkManager
        except ImportError as e:
            self.skipTest(f'NetworkManager import failed: {e}')
        os.makedirs(os.path.join(tmp_dir, 'dns'), exist_ok=True)
        return NetworkManager(data_dir=tmp_dir, config_dir=tmp_dir)

    def setUp(self):
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)

    def _assert_rejected(self, domain, dns_ip):
        """Assert add_cell_dns_forward() rejects the input without applying DNS rules."""
        nm = self._get_network_manager(self.tmp)
        # Patched so a validation regression cannot touch the real CoreDNS,
        # and so the "does not call apply_all_dns_rules()" contract is checked.
        with patch.object(firewall_manager, 'apply_all_dns_rules') as mock_apply, \
                patch.object(firewall_manager, 'reload_coredns'):
            result = nm.add_cell_dns_forward(domain, dns_ip)
        self.assertTrue(result['warnings'])
        self.assertFalse(result['restarted'])
        mock_apply.assert_not_called()

    def test_invalid_dns_ip_returns_warning(self):
        self._assert_rejected('valid.cell', 'not-an-ip')

    def test_domain_with_newline_returns_warning(self):
        self._assert_rejected('evil\ndomain', '10.1.0.1')

    def test_domain_with_braces_returns_warning(self):
        self._assert_rejected('evil{domain}', '10.1.0.1')

    def test_domain_with_space_returns_warning(self):
        self._assert_rejected('evil domain', '10.1.0.1')

    def test_valid_domain_and_ip_calls_apply_all_dns_rules(self):
        """Valid inputs must call firewall_manager.apply_all_dns_rules()."""
        nm = self._get_network_manager(self.tmp)
        with patch.object(firewall_manager, 'apply_all_dns_rules', return_value=True) as mock_apply, \
                patch.object(firewall_manager, 'reload_coredns', return_value=True):
            nm.add_cell_dns_forward('valid.cell', '10.1.0.1')
        mock_apply.assert_called_once()
        args, kwargs = mock_apply.call_args
        # cell_links may arrive by keyword or as the 4th positional argument
        # (peers, corefile_path, domain, cell_links).
        if 'cell_links' in kwargs:
            cell_links_arg = kwargs['cell_links']
        else:
            self.assertGreater(
                len(args), 3, 'cell_links was not passed to apply_all_dns_rules()'
            )
            cell_links_arg = args[3]
        self.assertIsNotNone(cell_links_arg)
        # The new entry must be present with its own dns_ip, not just its domain.
        pairs = [(link.get('domain'), link.get('dns_ip')) for link in cell_links_arg]
        self.assertIn(('valid.cell', '10.1.0.1'), pairs)


if __name__ == '__main__':
    unittest.main()