#!/usr/bin/env python3
"""
Unit tests for WireGuardManager class
"""

import sys
from pathlib import Path

# Add api directory to path
api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))

import unittest
import tempfile
import os
import json
import shutil
import base64
from unittest.mock import patch, MagicMock
from datetime import datetime

# Also add the parent of this file's directory to the path for imports.
# (`sys` is already imported above; the former duplicate import was removed.)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from wireguard_manager import WireGuardManager


class _WireGuardEnvMixin:
    """Shared fixture: a throwaway data/config tree plus a WireGuardManager.

    Mix into a ``unittest.TestCase`` subclass; it relies on ``addCleanup``.
    """

    def _make_manager(self):
        """Create temp dirs, stub ``_syncconf``, and return a new manager.

        Sets ``self.test_dir``, ``self.data_dir``, ``self.config_dir`` and
        ``self.mock_sync`` on the test case.

        Cleanups are registered *before* the manager is constructed. If
        ``WireGuardManager(...)`` raises inside ``setUp``, ``tearDown`` is
        never called, but registered cleanups still run. The temp tree is
        therefore always removed and the patch always undone.
        """
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir, ignore_errors=True)

        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)

        # Replace _syncconf with a no-op mock for the whole test.
        # NOTE(review): presumably it pushes the config to a live interface;
        # stubbing it keeps the tests hermetic.
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)

        return WireGuardManager(self.data_dir, self.config_dir)


class TestWireGuardManager(_WireGuardEnvMixin, unittest.TestCase):
    """Test cases for WireGuardManager class"""

    def setUp(self):
        """Set up an isolated manager; cleanup is registered by the helper."""
        self.wg_manager = self._make_manager()

    def test_initialization(self):
        """Test WireGuardManager initialization"""
        self.assertEqual(self.wg_manager.data_dir, self.data_dir)
        self.assertEqual(self.wg_manager.config_dir, self.config_dir)
        self.assertTrue(os.path.exists(self.wg_manager.wireguard_dir))
        self.assertTrue(os.path.exists(self.wg_manager.keys_dir))

    def test_key_generation(self):
        """Test WireGuard key generation"""
        # Check if keys were generated
        private_key_file = os.path.join(self.wg_manager.keys_dir, 'private.key')
        public_key_file = os.path.join(self.wg_manager.keys_dir, 'public.key')
        self.assertTrue(os.path.exists(private_key_file))
        self.assertTrue(os.path.exists(public_key_file))

        # Check key content
        with open(private_key_file, 'rb') as f:
            private_key = f.read()
        self.assertIsInstance(private_key, bytes)
        self.assertGreater(len(private_key), 0)

        with open(public_key_file, 'rb') as f:
            public_key = f.read()
        self.assertIsInstance(public_key, bytes)
        self.assertGreater(len(public_key), 0)

    def test_get_keys(self):
        """Test getting WireGuard keys"""
        keys = self.wg_manager.get_keys()
        self.assertIn('private_key', keys)
        self.assertIn('public_key', keys)
        self.assertIsInstance(keys['private_key'], str)
        self.assertIsInstance(keys['public_key'], str)
        self.assertGreater(len(keys['private_key']), 0)
        self.assertGreater(len(keys['public_key']), 0)

    def test_generate_peer_keys(self):
        """Test generating keys for a peer"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.assertIn('private_key', peer_keys)
        self.assertIn('public_key', peer_keys)
        self.assertIsInstance(peer_keys['private_key'], str)
        self.assertIsInstance(peer_keys['public_key'], str)

        # Check if peer keys were saved
        peer_keys_dir = os.path.join(self.wg_manager.keys_dir, 'peers')
        peer_private_file = os.path.join(peer_keys_dir, 'testpeer_private.key')
        peer_public_file = os.path.join(peer_keys_dir, 'testpeer_public.key')
        self.assertTrue(os.path.exists(peer_private_file))
        self.assertTrue(os.path.exists(peer_public_file))

    def test_generate_config(self):
        """Test WireGuard configuration generation"""
        config = self.wg_manager.generate_config('wg0', 51820)
        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 10.0.0.1/24', config)
        self.assertIn('ListenPort = 51820', config)
        self.assertIn('PostUp', config)
        self.assertIn('PostDown', config)

    def test_add_peer(self):
        """Test adding a peer — server-side AllowedIPs must be /32."""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        success = self.wg_manager.add_peer(
            'testpeer', peer_keys['public_key'], '', '10.0.0.2/32', 25
        )
        self.assertTrue(success)

        config_file = self.wg_manager._config_file()
        self.assertTrue(os.path.exists(config_file))
        with open(config_file, 'r') as f:
            config = f.read()
        self.assertIn('[Peer]', config)
        self.assertIn(peer_keys['public_key'], config)
        self.assertIn('AllowedIPs = 10.0.0.2/32', config)
        self.assertIn('PersistentKeepalive = 25', config)

    def test_remove_peer(self):
        """Test removing a peer from WireGuard configuration"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')

        # Remove the peer
        success = self.wg_manager.remove_peer(peer_keys['public_key'])
        self.assertTrue(success)

        # Check if peer was removed
        config_file = self.wg_manager._config_file()
        with open(config_file, 'r') as f:
            config = f.read()
        self.assertNotIn(peer_keys['public_key'], config)

    def test_get_peers(self):
        """Test getting list of configured peers"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')

        peers = self.wg_manager.get_peers()
        self.assertIsInstance(peers, list)
        self.assertEqual(len(peers), 1)
        self.assertIn('public_key', peers[0])
        self.assertIn('allowed_ips', peers[0])
        self.assertIn('persistent_keepalive', peers[0])
        self.assertEqual(peers[0]['public_key'], peer_keys['public_key'])

    @patch('subprocess.run')
    def test_get_status(self, mock_run):
        """Test getting WireGuard status"""
        # Mock WireGuard service running
        mock_run.return_value.stdout = 'cell-wireguard\n'
        mock_run.return_value.returncode = 0

        status = self.wg_manager.get_status()
        self.assertTrue(status['running'])
        self.assertIn('interface', status)
        self.assertIn('ip_info', status)

    @patch('subprocess.run')
    def test_get_status_not_running(self, mock_run):
        """Test getting WireGuard status when service is not running"""
        # Mock WireGuard service not running
        mock_run.return_value.stdout = ''
        mock_run.return_value.returncode = 0

        status = self.wg_manager.get_status()
        self.assertFalse(status['running'])

    @patch('subprocess.run')
    def test_test_connectivity(self, mock_run):
        """Test connectivity testing"""
        # Mock successful ping
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'PING 192.168.1.100'
        mock_run.return_value.stderr = ''

        result = self.wg_manager.test_connectivity('192.168.1.100')
        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertTrue(result['ping_success'])
        self.assertIn('192.168.1.100', result['ping_output'])

    @patch('subprocess.run')
    def test_test_connectivity_failure(self, mock_run):
        """Test connectivity testing with failure"""
        # Mock failed ping
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        mock_run.return_value.stderr = 'No route to host'

        result = self.wg_manager.test_connectivity('192.168.1.100')
        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertFalse(result['ping_success'])
        self.assertIn('No route to host', result['ping_error'])

    def test_update_peer_ip(self):
        """Test updating peer IP address"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')

        success = self.wg_manager.update_peer_ip(peer_keys['public_key'], '10.0.0.9/32')
        self.assertTrue(success)

        with open(self.wg_manager._config_file(), 'r') as f:
            config = f.read()
        self.assertIn('10.0.0.9/32', config)

    def test_get_peer_config(self):
        """Test generating peer client configuration."""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        keys = self.wg_manager.get_keys()

        config = self.wg_manager.get_peer_config('testpeer', '10.0.0.2', peer_keys['private_key'])
        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('[Peer]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 10.0.0.2/32', config)
        self.assertIn('DNS = 172.20.0.3', config)
        self.assertIn(keys['public_key'], config)
        self.assertIn('AllowedIPs', config)

    def test_multiple_peers(self):
        """Test managing multiple peers"""
        peer1_keys = self.wg_manager.generate_peer_keys('peer1')
        success1 = self.wg_manager.add_peer('peer1', peer1_keys['public_key'], '', '10.0.0.2/32')
        self.assertTrue(success1)

        peer2_keys = self.wg_manager.generate_peer_keys('peer2')
        success2 = self.wg_manager.add_peer('peer2', peer2_keys['public_key'], '', '10.0.0.3/32')
        self.assertTrue(success2)

        # Get peers
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 2)

        # Remove first peer
        success3 = self.wg_manager.remove_peer(peer1_keys['public_key'])
        self.assertTrue(success3)

        # Check remaining peers
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 1)
        self.assertEqual(peers[0]['public_key'], peer2_keys['public_key'])

    def test_config_file_parsing(self):
        """Test parsing WireGuard configuration file"""
        # Create a test config file
        config_file = os.path.join(self.wg_manager.wireguard_dir, 'wg0.conf')
        test_config = """[Interface]
PrivateKey = test_private_key
Address = 172.20.0.1/16
ListenPort = 51820

[Peer]
PublicKey = peer1_public_key
AllowedIPs = 172.20.0.0/16
PersistentKeepalive = 25

[Peer]
PublicKey = peer2_public_key
AllowedIPs = 172.20.1.0/24
PersistentKeepalive = 30
"""
        with open(config_file, 'w') as f:
            f.write(test_config)

        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 2)
        self.assertEqual(peers[0]['public_key'], 'peer1_public_key')
        self.assertEqual(peers[0]['allowed_ips'], '172.20.0.0/16')
        self.assertEqual(peers[0]['persistent_keepalive'], 25)
        self.assertEqual(peers[1]['public_key'], 'peer2_public_key')
        self.assertEqual(peers[1]['allowed_ips'], '172.20.1.0/24')
        self.assertEqual(peers[1]['persistent_keepalive'], 30)

    def test_error_handling(self):
        """Test error handling in WireGuard operations."""
        # Wide CIDR rejected — server-side AllowedIPs must be /32
        success = self.wg_manager.add_peer('testpeer', 'invalid_key', '', '172.20.0.0/16')
        self.assertFalse(success, "Wide CIDR must be rejected")

        # Valid /32 with any key string is accepted (key format not validated at this layer)
        success = self.wg_manager.add_peer('testpeer', 'any_key_string=', '', '10.0.0.2/32')
        self.assertTrue(success)

        # Removing non-existent peer is a no-op, not an error
        success = self.wg_manager.remove_peer('non_existent_key')
        self.assertTrue(success)

        # Updating IP for peer not in config returns False
        success = self.wg_manager.update_peer_ip('non_existent_key', '10.0.0.9/32')
        self.assertFalse(success)


class TestWireGuardCellPeer(_WireGuardEnvMixin, unittest.TestCase):
    """Test add_cell_peer allows subnet CIDRs for site-to-site connections."""

    def setUp(self):
        """Set up an isolated manager; cleanup is registered by the helper."""
        self.wg = self._make_manager()

    def test_add_cell_peer_allows_subnet_cidr(self):
        ok = self.wg.add_cell_peer('remote', 'remotepubkey=', '5.6.7.8:51821', '10.1.0.0/24')
        self.assertTrue(ok)
        content = self.wg._read_config()
        self.assertIn('10.1.0.0/24', content)

    def test_add_cell_peer_writes_full_endpoint(self):
        self.wg.add_cell_peer('remote', 'remotepubkey=', '5.6.7.8:51821', '10.1.0.0/24')
        content = self.wg._read_config()
        self.assertIn('Endpoint = 5.6.7.8:51821', content)

    def test_add_cell_peer_comment_has_cell_prefix(self):
        self.wg.add_cell_peer('remote', 'remotepubkey=', '5.6.7.8:51821', '10.1.0.0/24')
        content = self.wg._read_config()
        self.assertIn('# cell:remote', content)

    def test_add_cell_peer_invalid_cidr_returns_false(self):
        ok = self.wg.add_cell_peer('remote', 'remotepubkey=', '5.6.7.8:51821', 'not-a-cidr')
        self.assertFalse(ok)

    def test_add_cell_peer_can_coexist_with_regular_peers(self):
        self.wg.add_peer('alice', 'alicepubkey=', '', '10.0.0.2/32')
        self.wg.add_cell_peer('remote', 'remotepubkey=', '5.6.7.8:51821', '10.1.0.0/24')
        content = self.wg._read_config()
        self.assertIn('alicepubkey=', content)
        self.assertIn('remotepubkey=', content)


class TestWireGuardConfigReads(_WireGuardEnvMixin, unittest.TestCase):
    """Test that port/address/network are read from wg0.conf, not hardcoded."""

    def setUp(self):
        """Set up an isolated manager; cleanup is registered by the helper."""
        self.wg = self._make_manager()

    def _write_wg_conf(self, port=51820, address='10.0.0.1/24', extra=''):
        """Write a minimal [Interface] section to the manager's config file.

        ``extra`` is appended verbatim after the ListenPort line.
        """
        conf = (
            f'[Interface]\n'
            f'PrivateKey = dummykey\n'
            f'Address = {address}\n'
            f'ListenPort = {port}\n'
            f'{extra}'
        )
        cf = self.wg._config_file()
        os.makedirs(os.path.dirname(cf), exist_ok=True)
        with open(cf, 'w') as f:
            f.write(conf)

    def test_get_configured_port_reads_from_wg_conf(self):
        self._write_wg_conf(port=54321)
        self.assertEqual(self.wg._get_configured_port(), 54321)

    def test_get_configured_port_fallback_when_no_file(self):
        # No wg0.conf exists — fall back to DEFAULT_PORT
        self.assertEqual(self.wg._get_configured_port(), 51820)

    def test_get_configured_address_reads_from_wg_conf(self):
        self._write_wg_conf(address='10.1.0.1/24')
        self.assertEqual(self.wg._get_configured_address(), '10.1.0.1/24')

    def test_get_configured_network_derives_from_address(self):
        self._write_wg_conf(address='10.1.0.1/24')
        self.assertEqual(self.wg._get_configured_network(), '10.1.0.0/24')

    def test_get_split_tunnel_ips_uses_configured_network(self):
        self._write_wg_conf(address='10.1.0.1/24')
        split = self.wg.get_split_tunnel_ips()
        self.assertIn('10.1.0.0/24', split)
        self.assertIn('172.20.0.0/16', split)
        self.assertNotIn('10.0.0.0/24', split)

    def test_get_server_config_uses_configured_port(self):
        self._write_wg_conf(port=54321)
        with patch.object(self.wg, 'get_external_ip', return_value='1.2.3.4'):
            cfg = self.wg.get_server_config()
        self.assertEqual(cfg['port'], 54321)
        self.assertIn(':54321', cfg['endpoint'])

    def test_get_server_config_includes_dns_and_split_tunnel(self):
        self._write_wg_conf(address='10.2.0.1/24')
        with patch.object(self.wg, 'get_external_ip', return_value='1.2.3.4'):
            cfg = self.wg.get_server_config()
        self.assertIn('dns_ip', cfg)
        self.assertIn('split_tunnel_ips', cfg)
        self.assertIn('10.2.0.0/24', cfg['split_tunnel_ips'])

    def test_get_peer_config_uses_configured_port_in_endpoint(self):
        self._write_wg_conf(port=54321)
        result = self.wg.get_peer_config(
            peer_name='alice',
            peer_ip='10.0.0.2',
            peer_private_key='privkeyalice=',
            server_endpoint='5.6.7.8',
        )
        self.assertIn(':54321', result)
        self.assertNotIn(':51820', result)

    def test_add_peer_uses_configured_port_in_endpoint(self):
        self._write_wg_conf(port=54321)
        self.wg.add_peer('alice', 'pubkeyalice=', '5.6.7.8', '10.0.0.2/32')
        content = self.wg._read_config()
        self.assertIn('Endpoint = 5.6.7.8:54321', content)
        self.assertNotIn(':51820', content)


if __name__ == '__main__':
    unittest.main()