#!/usr/bin/env python3
"""
Unit tests for the WireGuardManager class.

Every test case runs against a fresh WireGuardManager rooted in a private
temporary directory, with ``WireGuardManager._syncconf`` patched out so no
test touches a live WireGuard interface.
"""

import base64
import json
import os
import shutil
import sys
import tempfile
import time
import unittest
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch

# Make modules in the sibling ``api`` directory importable (searched first).
api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))
# Also make the parent directory importable (searched last).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from wireguard_manager import WireGuardManager  # noqa: E402  (needs sys.path above)

# Public keys are 44 chars ending in '=' — required by validation in add_cell_peer.
CELL_PUBKEY = 'cmVtb3RlcHVia2V5X2Zvcl90ZXN0c193Z3Rlc3QxMiE='
ALICE_PUBKEY = 'YWxpY2VwdWJrZXlfZm9yX3Rlc3RzX3dndGVzdDEyMyE='


class _WireGuardFixtureMixin:
    """Shared fixture helpers for the WireGuardManager test cases.

    Mixed into ``unittest.TestCase`` subclasses. It deliberately does not
    inherit from TestCase, so the test runner never collects it on its own.
    """

    def _make_manager(self, *, shared_dir=False):
        """Create temp dirs, stub out ``_syncconf``, and return a fresh manager.

        Each cleanup is registered immediately after its resource is
        acquired. The temp dir is therefore removed and the patch undone even
        if the ``WireGuardManager`` constructor raises (``tearDown`` would not
        run in that case).

        Args:
            shared_dir: when True, the temp dir itself serves as both the data
                and the config dir. Otherwise separate ``data`` and ``config``
                sub-directories are created inside it.

        Returns:
            A ``WireGuardManager`` instance. The ``_syncconf`` mock is exposed
            as ``self.mock_sync``.
        """
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir, ignore_errors=True)

        if shared_dir:
            self.data_dir = self.config_dir = self.test_dir
        else:
            self.data_dir = os.path.join(self.test_dir, 'data')
            self.config_dir = os.path.join(self.test_dir, 'config')
            os.makedirs(self.data_dir, exist_ok=True)
            os.makedirs(self.config_dir, exist_ok=True)

        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)

        return WireGuardManager(self.data_dir, self.config_dir)

    @staticmethod
    def _write_config_text(manager, text):
        """Write *text* to ``manager._config_file()``, creating parent dirs."""
        cf = manager._config_file()
        os.makedirs(os.path.dirname(cf), exist_ok=True)
        with open(cf, 'w') as f:
            f.write(text)

    def _write_wg_conf(self, port=51820, address='10.0.0.1/24', extra=''):
        """Write a minimal ``[Interface]`` section to ``self.wg``'s config file.

        Args:
            port: value for ``ListenPort``.
            address: value for ``Address`` (interface IP with prefix length).
            extra: raw text appended after the interface section.
        """
        conf = (
            f'[Interface]\n'
            f'PrivateKey = dummykey\n'
            f'Address = {address}\n'
            f'ListenPort = {port}\n'
            f'{extra}'
        )
        self._write_config_text(self.wg, conf)


class TestWireGuardManager(_WireGuardFixtureMixin, unittest.TestCase):
    """Test cases for WireGuardManager class"""

    def setUp(self):
        """Set up test environment"""
        self.wg_manager = self._make_manager()

    def test_initialization(self):
        """Test WireGuardManager initialization"""
        self.assertEqual(self.wg_manager.data_dir, self.data_dir)
        self.assertEqual(self.wg_manager.config_dir, self.config_dir)
        self.assertTrue(os.path.exists(self.wg_manager.wireguard_dir))
        self.assertTrue(os.path.exists(self.wg_manager.keys_dir))

    def test_key_generation(self):
        """Test WireGuard key generation"""
        private_key_file = os.path.join(self.wg_manager.keys_dir, 'private.key')
        public_key_file = os.path.join(self.wg_manager.keys_dir, 'public.key')

        self.assertTrue(os.path.exists(private_key_file))
        self.assertTrue(os.path.exists(public_key_file))

        with open(private_key_file, 'rb') as f:
            private_key = f.read()
        self.assertIsInstance(private_key, bytes)
        self.assertGreater(len(private_key), 0)

        with open(public_key_file, 'rb') as f:
            public_key = f.read()
        self.assertIsInstance(public_key, bytes)
        self.assertGreater(len(public_key), 0)

    def test_get_keys(self):
        """Test getting WireGuard keys"""
        keys = self.wg_manager.get_keys()

        self.assertIn('private_key', keys)
        self.assertIn('public_key', keys)
        self.assertIsInstance(keys['private_key'], str)
        self.assertIsInstance(keys['public_key'], str)
        self.assertGreater(len(keys['private_key']), 0)
        self.assertGreater(len(keys['public_key']), 0)

    def test_generate_peer_keys(self):
        """Test generating keys for a peer"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')

        self.assertIn('private_key', peer_keys)
        self.assertIn('public_key', peer_keys)
        self.assertIsInstance(peer_keys['private_key'], str)
        self.assertIsInstance(peer_keys['public_key'], str)

        # Peer keys must also be persisted to disk.
        peer_keys_dir = os.path.join(self.wg_manager.keys_dir, 'peers')
        peer_private_file = os.path.join(peer_keys_dir, 'testpeer_private.key')
        peer_public_file = os.path.join(peer_keys_dir, 'testpeer_public.key')
        self.assertTrue(os.path.exists(peer_private_file))
        self.assertTrue(os.path.exists(peer_public_file))

    def test_generate_config(self):
        """Test WireGuard configuration generation"""
        config = self.wg_manager.generate_config('wg0', 51820)

        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 10.0.0.1/24', config)
        self.assertIn('ListenPort = 51820', config)
        self.assertIn('PostUp', config)
        self.assertIn('PostDown', config)

    def test_add_peer(self):
        """Test adding a peer — server-side AllowedIPs must be /32."""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        success = self.wg_manager.add_peer(
            'testpeer', peer_keys['public_key'], '', '10.0.0.2/32', 25
        )
        self.assertTrue(success)

        config_file = self.wg_manager._config_file()
        self.assertTrue(os.path.exists(config_file))
        with open(config_file, 'r') as f:
            config = f.read()
        self.assertIn('[Peer]', config)
        self.assertIn(peer_keys['public_key'], config)
        self.assertIn('AllowedIPs = 10.0.0.2/32', config)
        self.assertIn('PersistentKeepalive = 25', config)

    def test_remove_peer(self):
        """Test removing a peer from WireGuard configuration"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')

        success = self.wg_manager.remove_peer(peer_keys['public_key'])
        self.assertTrue(success)

        with open(self.wg_manager._config_file(), 'r') as f:
            config = f.read()
        self.assertNotIn(peer_keys['public_key'], config)

    def test_get_peers(self):
        """Test getting list of configured peers"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')

        peers = self.wg_manager.get_peers()
        self.assertIsInstance(peers, list)
        self.assertEqual(len(peers), 1)
        self.assertIn('public_key', peers[0])
        self.assertIn('allowed_ips', peers[0])
        self.assertIn('persistent_keepalive', peers[0])
        self.assertEqual(peers[0]['public_key'], peer_keys['public_key'])

    @patch('subprocess.run')
    def test_get_status(self, mock_run):
        """Test getting WireGuard status"""
        # Simulate the WireGuard service being reported as running.
        mock_run.return_value.stdout = 'cell-wireguard\n'
        mock_run.return_value.returncode = 0

        status = self.wg_manager.get_status()
        self.assertTrue(status['running'])
        self.assertIn('interface', status)
        self.assertIn('ip_info', status)

    @patch('subprocess.run')
    def test_get_status_not_running(self, mock_run):
        """Test getting WireGuard status when service is not running"""
        mock_run.return_value.stdout = ''
        mock_run.return_value.returncode = 0

        status = self.wg_manager.get_status()
        self.assertFalse(status['running'])

    @patch('subprocess.run')
    def test_test_connectivity(self, mock_run):
        """Test connectivity testing"""
        # Simulate a successful ping.
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'PING 192.168.1.100'
        mock_run.return_value.stderr = ''

        result = self.wg_manager.test_connectivity('192.168.1.100')
        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertTrue(result['ping_success'])
        self.assertIn('192.168.1.100', result['ping_output'])

    @patch('subprocess.run')
    def test_test_connectivity_failure(self, mock_run):
        """Test connectivity testing with failure"""
        # Simulate a failed ping.
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        mock_run.return_value.stderr = 'No route to host'

        result = self.wg_manager.test_connectivity('192.168.1.100')
        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertFalse(result['ping_success'])
        self.assertIn('No route to host', result['ping_error'])

    def test_update_peer_ip(self):
        """Test updating peer IP address"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')

        success = self.wg_manager.update_peer_ip(peer_keys['public_key'], '10.0.0.9/32')
        self.assertTrue(success)

        with open(self.wg_manager._config_file(), 'r') as f:
            config = f.read()
        self.assertIn('10.0.0.9/32', config)

    def test_get_peer_config(self):
        """Test generating peer client configuration."""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        keys = self.wg_manager.get_keys()

        config = self.wg_manager.get_peer_config('testpeer', '10.0.0.2', peer_keys['private_key'])
        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('[Peer]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 10.0.0.2/32', config)
        self.assertIn('DNS = 172.20.0.3', config)
        self.assertIn(keys['public_key'], config)
        self.assertIn('AllowedIPs', config)

    def test_multiple_peers(self):
        """Test managing multiple peers"""
        peer1_keys = self.wg_manager.generate_peer_keys('peer1')
        success1 = self.wg_manager.add_peer('peer1', peer1_keys['public_key'], '', '10.0.0.2/32')
        self.assertTrue(success1)

        peer2_keys = self.wg_manager.generate_peer_keys('peer2')
        success2 = self.wg_manager.add_peer('peer2', peer2_keys['public_key'], '', '10.0.0.3/32')
        self.assertTrue(success2)

        self.assertEqual(len(self.wg_manager.get_peers()), 2)

        success3 = self.wg_manager.remove_peer(peer1_keys['public_key'])
        self.assertTrue(success3)

        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 1)
        self.assertEqual(peers[0]['public_key'], peer2_keys['public_key'])

    def test_config_file_parsing(self):
        """Test parsing WireGuard configuration file"""
        test_config = (
            '[Interface]\n'
            'PrivateKey = test_private_key\n'
            'Address = 172.20.0.1/16\n'
            'ListenPort = 51820\n'
            '\n'
            '[Peer]\n'
            'PublicKey = peer1_public_key\n'
            'AllowedIPs = 172.20.0.0/16\n'
            'PersistentKeepalive = 25\n'
            '\n'
            '[Peer]\n'
            'PublicKey = peer2_public_key\n'
            'AllowedIPs = 172.20.1.0/24\n'
            'PersistentKeepalive = 30\n'
        )
        # Write to the same path the manager reads from (_config_file()).
        self._write_config_text(self.wg_manager, test_config)

        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 2)

        self.assertEqual(peers[0]['public_key'], 'peer1_public_key')
        self.assertEqual(peers[0]['allowed_ips'], '172.20.0.0/16')
        self.assertEqual(peers[0]['persistent_keepalive'], 25)

        self.assertEqual(peers[1]['public_key'], 'peer2_public_key')
        self.assertEqual(peers[1]['allowed_ips'], '172.20.1.0/24')
        self.assertEqual(peers[1]['persistent_keepalive'], 30)

    def test_error_handling(self):
        """Test error handling in WireGuard operations."""
        # Wide CIDR rejected — server-side AllowedIPs must be /32.
        success = self.wg_manager.add_peer('testpeer', 'invalid_key', '', '172.20.0.0/16')
        self.assertFalse(success, "Wide CIDR must be rejected")

        # A valid /32 with any key string is accepted (key format is not
        # validated at this layer).
        success = self.wg_manager.add_peer(
            'testpeer', 'YW55X2tleV9zdHJpbmdfZm9yX3Rlc3RzX3dnMTIzISE=', '', '10.0.0.2/32'
        )
        self.assertTrue(success)

        # Removing a non-existent peer is a no-op, not an error.
        success = self.wg_manager.remove_peer('non_existent_key')
        self.assertTrue(success)

        # Updating the IP of a peer that is not in the config returns False.
        success = self.wg_manager.update_peer_ip('non_existent_key', '10.0.0.9/32')
        self.assertFalse(success)


class TestWireGuardCellPeer(_WireGuardFixtureMixin, unittest.TestCase):
    """Test add_cell_peer allows subnet CIDRs for site-to-site connections."""

    def setUp(self):
        self.wg = self._make_manager()

    def test_add_cell_peer_allows_subnet_cidr(self):
        ok = self.wg.add_cell_peer('remote', CELL_PUBKEY, '5.6.7.8:51821', '10.1.0.0/24')
        self.assertTrue(ok)
        self.assertIn('10.1.0.0/24', self.wg._read_config())

    def test_add_cell_peer_writes_full_endpoint(self):
        self.wg.add_cell_peer('remote', CELL_PUBKEY, '5.6.7.8:51821', '10.1.0.0/24')
        self.assertIn('Endpoint = 5.6.7.8:51821', self.wg._read_config())

    def test_add_cell_peer_comment_has_cell_prefix(self):
        self.wg.add_cell_peer('remote', CELL_PUBKEY, '5.6.7.8:51821', '10.1.0.0/24')
        self.assertIn('# cell:remote', self.wg._read_config())

    def test_add_cell_peer_invalid_cidr_returns_false(self):
        ok = self.wg.add_cell_peer('remote', CELL_PUBKEY, '5.6.7.8:51821', 'not-a-cidr')
        self.assertFalse(ok)

    def test_add_cell_peer_can_coexist_with_regular_peers(self):
        self.wg.add_peer('alice', ALICE_PUBKEY, '', '10.0.0.2/32')
        self.wg.add_cell_peer('remote', CELL_PUBKEY, '5.6.7.8:51821', '10.1.0.0/24')
        content = self.wg._read_config()
        self.assertIn(ALICE_PUBKEY, content)
        self.assertIn(CELL_PUBKEY, content)


class TestWireGuardConfigReads(_WireGuardFixtureMixin, unittest.TestCase):
    """Test that port/address/network are read from wg0.conf, not hardcoded."""

    def setUp(self):
        self.wg = self._make_manager()

    def test_get_configured_port_reads_from_wg_conf(self):
        self._write_wg_conf(port=54321)
        self.assertEqual(self.wg._get_configured_port(), 54321)

    def test_get_configured_port_fallback_when_no_file(self):
        # No wg0.conf exists — fall back to DEFAULT_PORT.
        self.assertEqual(self.wg._get_configured_port(), 51820)

    def test_get_configured_address_reads_from_wg_conf(self):
        self._write_wg_conf(address='10.1.0.1/24')
        self.assertEqual(self.wg._get_configured_address(), '10.1.0.1/24')

    def test_get_configured_network_derives_from_address(self):
        self._write_wg_conf(address='10.1.0.1/24')
        self.assertEqual(self.wg._get_configured_network(), '10.1.0.0/24')

    def test_get_split_tunnel_ips_uses_configured_network(self):
        self._write_wg_conf(address='10.1.0.1/24')
        split = self.wg.get_split_tunnel_ips()
        self.assertIn('10.1.0.0/24', split)
        self.assertIn('172.20.0.0/16', split)
        self.assertNotIn('10.0.0.0/24', split)

    def test_get_server_config_uses_configured_port(self):
        self._write_wg_conf(port=54321)
        with patch.object(self.wg, 'get_external_ip', return_value='1.2.3.4'):
            cfg = self.wg.get_server_config()
        self.assertEqual(cfg['port'], 54321)
        self.assertIn(':54321', cfg['endpoint'])

    def test_get_server_config_includes_dns_and_split_tunnel(self):
        self._write_wg_conf(address='10.2.0.1/24')
        with patch.object(self.wg, 'get_external_ip', return_value='1.2.3.4'):
            cfg = self.wg.get_server_config()
        self.assertIn('dns_ip', cfg)
        self.assertIn('split_tunnel_ips', cfg)
        self.assertIn('10.2.0.0/24', cfg['split_tunnel_ips'])

    def test_get_peer_config_uses_configured_port_in_endpoint(self):
        self._write_wg_conf(port=54321)
        result = self.wg.get_peer_config(
            peer_name='alice',
            peer_ip='10.0.0.2',
            peer_private_key='privkeyalice=',
            server_endpoint='5.6.7.8',
        )
        self.assertIn(':54321', result)
        self.assertNotIn(':51820', result)

    def test_add_peer_uses_configured_port_in_endpoint(self):
        self._write_wg_conf(port=54321)
        self.wg.add_peer('alice', 'cHVia2V5YWxpY2VfZm9yX3Rlc3RzX3dpcmVndWFyZCE=',
                         '5.6.7.8', '10.0.0.2/32')
        content = self.wg._read_config()
        self.assertIn('Endpoint = 5.6.7.8:54321', content)
        self.assertNotIn(':51820', content)


class TestWireGuardSysctlAndPortCheck(_WireGuardFixtureMixin, unittest.TestCase):
    """Tests for sysctl safety, port check, and peer status parsing."""

    def setUp(self):
        # Data and config share one directory in this suite.
        self.wg = self._make_manager(shared_dir=True)

    # ── generate_config sysctl safety ────────────────────────────────────────

    def test_generate_config_postup_has_nonfatal_sysctl(self):
        cfg = self.wg.generate_config()
        self.assertIn('sysctl -q net.ipv4.conf.all.rp_filter=0 || true', cfg)

    def test_generate_config_postdown_has_nonfatal_sysctl(self):
        cfg = self.wg.generate_config()
        self.assertIn('sysctl -q net.ipv4.conf.all.rp_filter=1 || true', cfg)

    def test_generate_config_has_masquerade(self):
        self.assertIn('MASQUERADE', self.wg.generate_config())

    def test_generate_config_has_forward_rule(self):
        self.assertIn('FORWARD -i %i -j ACCEPT', self.wg.generate_config())

    # ── check_port_open ───────────────────────────────────────────────────────

    @patch('subprocess.run')
    def test_check_port_open_when_wg_interface_up(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 51820\n'
        self.assertTrue(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_false_when_interface_down(self, mock_run):
        # `wg show` fails (no device); the `wg show ... dump` fallback fails too.
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        self.assertFalse(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_fallback_to_recent_handshake(self, mock_run):
        # First call (wg show wg0): interface is not reported as listening.
        # Second call (wg show wg0 dump): a peer with a recent handshake.
        now = int(time.time())
        dump_line = f'pubkey\t(none)\t1.2.3.4:51820\t0.0.0.0/0\t{now - 10}\t1000\t2000\t25\n'

        def side_effect(*args, **kwargs):
            cmd = args[0]
            m = MagicMock()
            m.returncode = 0
            if 'dump' in cmd:
                m.stdout = dump_line
            else:
                m.stdout = 'interface: wg0\n'  # no "listening port" text
            return m

        mock_run.side_effect = side_effect
        # "listening port" is missing from the first call's stdout, so the
        # check falls through to the dump, whose recent handshake → True.
        self.assertTrue(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_wrong_port_returns_false(self, mock_run):
        # wg0 is up on 51820 while wg0.conf says 51821 — must return False.
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 51820\n'
        # Configure 51821 so _get_configured_port() disagrees with the interface.
        self._write_wg_conf(port=51821)
        self.assertFalse(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_explicit_port_matches(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 12345\n'
        self.assertTrue(self.wg.check_port_open(port=12345))

    @patch('subprocess.run')
    def test_check_port_open_explicit_port_mismatch(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 51820\n'
        self.assertFalse(self.wg.check_port_open(port=51821))

    # ── get_peer_status ───────────────────────────────────────────────────────

    @patch('subprocess.run')
    def test_get_peer_status_online_with_recent_handshake(self, mock_run):
        now = int(time.time())
        pub = 'AAABBBCCC='
        dump = (
            'privkey\tserverpub\t51820\toff\n'  # interface line (4 fields)
            f'{pub}\t(none)\t1.2.3.4:12345\t10.0.0.2/32\t{now - 30}\t500\t1000\t25\n'
        )
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump

        st = self.wg.get_peer_status(pub)
        self.assertTrue(st['online'])
        self.assertIsNotNone(st['last_handshake'])
        self.assertLessEqual(st['last_handshake_seconds_ago'], 35)

    @patch('subprocess.run')
    def test_get_peer_status_offline_with_old_handshake(self, mock_run):
        now = int(time.time())
        pub = 'AAABBBCCC='
        dump = f'{pub}\t(none)\t(none)\t10.0.0.2/32\t{now - 300}\t0\t0\t25\n'
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump

        st = self.wg.get_peer_status(pub)
        self.assertFalse(st['online'])

    @patch('subprocess.run')
    def test_get_peer_status_not_found_returns_none_online(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = ''
        st = self.wg.get_peer_status('NOTEXIST=')
        self.assertIsNone(st['online'])

    @patch('subprocess.run')
    def test_get_peer_status_no_handshake_yet(self, mock_run):
        pub = 'AAABBBCCC='
        dump = f'{pub}\t(none)\t(none)\t10.0.0.2/32\t0\t0\t0\t25\n'
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump

        st = self.wg.get_peer_status(pub)
        self.assertFalse(st['online'])
        self.assertIsNone(st['last_handshake'])

    # ── get_all_peer_statuses ─────────────────────────────────────────────────

    @patch('subprocess.run')
    def test_get_all_peer_statuses_parses_multiple_peers(self, mock_run):
        now = int(time.time())
        pub1 = 'PUB1KEY='
        pub2 = 'PUB2KEY='
        dump = (
            'privkey\tserverpub\t51820\toff\n'
            f'{pub1}\t(none)\t1.1.1.1:1000\t10.0.0.2/32\t{now - 20}\t100\t200\t25\n'
            f'{pub2}\t(none)\t(none)\t10.0.0.3/32\t{now - 200}\t0\t0\t25\n'
        )
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump

        statuses = self.wg.get_all_peer_statuses()
        self.assertIn(pub1, statuses)
        self.assertIn(pub2, statuses)
        self.assertTrue(statuses[pub1]['online'])
        self.assertFalse(statuses[pub2]['online'])

    @patch('subprocess.run')
    def test_get_all_peer_statuses_empty_when_interface_down(self, mock_run):
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        self.assertEqual(self.wg.get_all_peer_statuses(), {})

    @patch('subprocess.run')
    def test_get_all_peer_statuses_skips_interface_line(self, mock_run):
        # The interface line has only 4 tab-separated fields and must not be
        # reported as a peer.
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'privkey\tserverpub\t51820\toff\n'
        self.assertEqual(self.wg.get_all_peer_statuses(), {})


class TestAddCellPeerSubnetOverlap(_WireGuardFixtureMixin, unittest.TestCase):
    """Verify that add_cell_peer rejects a vpn_subnet that overlaps the local WG network."""

    # Public key is 44 chars ending in '=' — required by validation in add_cell_peer.
    _CELL_PUBKEY = CELL_PUBKEY

    def setUp(self):
        self.wg = self._make_manager()
        # Write a known wg0.conf so _get_configured_network() returns 10.0.0.0/24.
        self._write_wg_conf(address='10.0.0.1/24')

    def test_add_cell_peer_overlapping_subnet_returns_false(self):
        """vpn_subnet that exactly matches the local WG network must be rejected."""
        # Local is 10.0.0.0/24; remote is also 10.0.0.0/24 — clear overlap.
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.0.0.0/24'
        )
        self.assertFalse(ok)

    def test_add_cell_peer_partially_overlapping_subnet_returns_false(self):
        """A remote subnet that contains the local network (e.g. /16 ⊃ /24) is rejected."""
        # 10.0.0.0/16 contains 10.0.0.0/24 → overlaps.
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.0.0.0/16'
        )
        self.assertFalse(ok)

    def test_add_cell_peer_non_overlapping_subnet_accepted(self):
        """A remote subnet distinct from the local WG network must be accepted."""
        # Local is 10.0.0.0/24; remote is 10.0.1.0/24 — no overlap.
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.0.1.0/24'
        )
        self.assertTrue(ok)

    def test_add_cell_peer_no_overlap_different_class_a(self):
        """A completely different address space is accepted."""
        # Local is 10.0.0.0/24; remote is 192.168.5.0/24 — no overlap.
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '192.168.5.0/24'
        )
        self.assertTrue(ok)

    def test_add_cell_peer_overlap_check_uses_configured_network(self):
        """When wg0.conf says 172.16.0.1/12, overlapping that range is rejected."""
        self._write_wg_conf(address='172.16.0.1/12')
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '172.16.0.0/12'
        )
        self.assertFalse(ok)


class TestCellRoutes(_WireGuardFixtureMixin, unittest.TestCase):
    """Tests for _ensure_cell_route and sync_cell_routes."""

    _CELL_PUBKEY = CELL_PUBKEY

    # Config path patched in so the manager believes it is outside a test dir.
    _PROD_CONFIG = '/app/config/wireguard/wg0.conf'

    def setUp(self):
        self.wg = self._make_manager()

    def test_ensure_cell_route_noop_in_test_dir(self):
        """_ensure_cell_route must not call subprocess when config is in /tmp (test env)."""
        with patch('subprocess.run') as mock_run:
            self.wg._ensure_cell_route('10.1.0.0/24')
        mock_run.assert_not_called()

    def test_sync_cell_routes_noop_in_test_dir(self):
        """sync_cell_routes must not call subprocess when config is in /tmp (test env)."""
        with patch('subprocess.run') as mock_run:
            self.wg.sync_cell_routes()
        mock_run.assert_not_called()

    def test_ensure_cell_route_calls_ip_route_add(self):
        """Outside test dirs, _ensure_cell_route calls docker exec ip route add."""
        with patch.object(self.wg, '_config_file', return_value=self._PROD_CONFIG):
            with patch('subprocess.run') as mock_run:
                mock_run.return_value = MagicMock(returncode=0)
                self.wg._ensure_cell_route('10.1.0.0/24')
        mock_run.assert_called_once()
        cmd = mock_run.call_args[0][0]
        for token in ('ip', 'route', 'add', '10.1.0.0/24', 'wg0'):
            self.assertIn(token, cmd)

    def test_sync_cell_routes_finds_cell_peers_in_config(self):
        """sync_cell_routes parses wg0.conf and adds routes for cell peers only."""
        conf = (
            '[Interface]\nPrivateKey = dummykey\nAddress = 10.0.0.1/24\nListenPort = 51820\n\n'
            '[Peer]\n# cell:remote\nPublicKey = cmVtb3RlcHVia2V5X2Zvcl90ZXN0c193Z3Rlc3QxMiE=\n'
            'AllowedIPs = 10.1.0.0/24\nPersistentKeepalive = 25\n\n'
            '[Peer]\n# alice\nPublicKey = YWxpY2VwdWJrZXlfZm9yX3Rlc3RzX3dndGVzdDEyMyE=\n'
            'AllowedIPs = 10.0.0.2/32\nPersistentKeepalive = 25\n'
        )
        with patch.object(self.wg, '_config_file', return_value=self._PROD_CONFIG):
            with patch.object(self.wg, '_read_config', return_value=conf):
                with patch('subprocess.run') as mock_run:
                    mock_run.return_value = MagicMock(returncode=0)
                    self.wg.sync_cell_routes()

        calls = [c[0][0] for c in mock_run.call_args_list]
        subnets = [c for c in calls if '10.1.0.0/24' in c]
        non_cell = [c for c in calls if '10.0.0.2/32' in c]
        self.assertGreaterEqual(len(subnets), 1, 'expected route add for cell peer subnet')
        self.assertEqual(len(non_cell), 0, 'should not add route for regular peer')

    def test_add_cell_peer_triggers_ensure_cell_route(self):
        """add_cell_peer calls _ensure_cell_route after writing config."""
        with patch.object(self.wg, '_ensure_cell_route') as mock_route:
            self.wg.add_cell_peer('remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.1.0.0/24')
        mock_route.assert_called_once_with('10.1.0.0/24')


if __name__ == '__main__':
    unittest.main()