#!/usr/bin/env python3
"""
Unit tests for WireGuardManager class
"""

import base64
import json
import os
import shutil
import sys
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch

# Absolute path of the directory that contains this test file's directory.
# Using os.path.abspath() keeps both sys.path entries below valid regardless of
# the current working directory, even if __file__ is a relative path.
_PARENT_DIR = Path(os.path.abspath(__file__)).parent.parent

# Add api directory to path (inserted first so its modules take precedence)
api_dir = _PARENT_DIR / 'api'
sys.path.insert(0, str(api_dir))

# Add parent directory to path for imports
sys.path.append(str(_PARENT_DIR))

from wireguard_manager import WireGuardManager  # noqa: E402 - needs the sys.path setup above


class TestWireGuardManager(unittest.TestCase):
    """Test cases for WireGuardManager class"""

    def setUp(self):
        """Create an isolated temporary directory tree and a manager using it."""
        self.test_dir = tempfile.mkdtemp()
        # Register cleanup before anything else can fail: cleanup functions run
        # even when the rest of setUp raises (e.g. in the WireGuardManager
        # constructor), whereas tearDown() would be skipped and the directory
        # leaked.
        self.addCleanup(shutil.rmtree, self.test_dir, ignore_errors=True)

        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)

        # Create WireGuardManager instance
        self.wg_manager = WireGuardManager(self.data_dir, self.config_dir)

    def _wg0_config_path(self):
        """Return the path of wg0.conf inside the manager's wireguard_dir."""
        return os.path.join(self.wg_manager.wireguard_dir, 'wg0.conf')

    def _read_wg0_config(self):
        """Return the current text content of wg0.conf."""
        with open(self._wg0_config_path(), 'r') as f:
            return f.read()

    def test_initialization(self):
        """Test WireGuardManager initialization"""
        self.assertEqual(self.wg_manager.data_dir, self.data_dir)
        self.assertEqual(self.wg_manager.config_dir, self.config_dir)
        self.assertTrue(os.path.exists(self.wg_manager.wireguard_dir))
        self.assertTrue(os.path.exists(self.wg_manager.keys_dir))

    def test_key_generation(self):
        """Test that the server key files exist in keys_dir and are non-empty."""
        for filename in ('private.key', 'public.key'):
            with self.subTest(key_file=filename):
                key_file = os.path.join(self.wg_manager.keys_dir, filename)
                self.assertTrue(os.path.exists(key_file))
                # Reading in binary mode always yields bytes, so the only
                # meaningful content check is that the file is not empty.
                with open(key_file, 'rb') as f:
                    self.assertGreater(len(f.read()), 0)

    def test_get_keys(self):
        """Test getting WireGuard keys"""
        keys = self.wg_manager.get_keys()

        self.assertIn('private_key', keys)
        self.assertIn('public_key', keys)
        self.assertIsInstance(keys['private_key'], str)
        self.assertIsInstance(keys['public_key'], str)
        self.assertGreater(len(keys['private_key']), 0)
        self.assertGreater(len(keys['public_key']), 0)

    def test_generate_peer_keys(self):
        """Test generating keys for a peer"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')

        self.assertIn('private_key', peer_keys)
        self.assertIn('public_key', peer_keys)
        self.assertIsInstance(peer_keys['private_key'], str)
        self.assertIsInstance(peer_keys['public_key'], str)

        # Check if peer keys were saved
        peer_keys_dir = os.path.join(self.wg_manager.keys_dir, 'peers')
        peer_private_file = os.path.join(peer_keys_dir, 'testpeer_private.key')
        peer_public_file = os.path.join(peer_keys_dir, 'testpeer_public.key')
        self.assertTrue(os.path.exists(peer_private_file))
        self.assertTrue(os.path.exists(peer_public_file))

    def test_generate_config(self):
        """Test WireGuard configuration generation"""
        config = self.wg_manager.generate_config('wg0', 51820)

        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 172.20.0.1/16', config)
        self.assertIn('ListenPort = 51820', config)
        self.assertIn('PostUp', config)
        self.assertIn('PostDown', config)

    def test_add_peer(self):
        """Test adding a peer to WireGuard configuration"""
        # Generate peer keys first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')

        success = self.wg_manager.add_peer(
            'testpeer',
            peer_keys['public_key'],
            '192.168.1.100',
            '172.20.0.0/16',
            25,
        )
        self.assertTrue(success)

        # Check if config file was created
        self.assertTrue(os.path.exists(self._wg0_config_path()))

        # Check config content
        config = self._read_wg0_config()
        self.assertIn('[Peer]', config)
        self.assertIn(peer_keys['public_key'], config)
        self.assertIn('AllowedIPs = 172.20.0.0/16', config)
        self.assertIn('PersistentKeepalive = 25', config)

    def test_remove_peer(self):
        """Test removing a peer from WireGuard configuration"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '192.168.1.100')

        # Remove the peer
        success = self.wg_manager.remove_peer(peer_keys['public_key'])
        self.assertTrue(success)

        # Check if peer was removed
        self.assertNotIn(peer_keys['public_key'], self._read_wg0_config())

    def test_get_peers(self):
        """Test getting list of configured peers"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '192.168.1.100')

        peers = self.wg_manager.get_peers()

        self.assertIsInstance(peers, list)
        self.assertEqual(len(peers), 1)
        self.assertIn('public_key', peers[0])
        self.assertIn('allowed_ips', peers[0])
        self.assertIn('persistent_keepalive', peers[0])
        self.assertEqual(peers[0]['public_key'], peer_keys['public_key'])

    @patch('subprocess.run')
    def test_get_status(self, mock_run):
        """Test getting WireGuard status"""
        # Mock WireGuard service running
        mock_run.return_value.stdout = 'cell-wireguard\n'
        mock_run.return_value.returncode = 0

        status = self.wg_manager.get_status()

        self.assertTrue(status['running'])
        self.assertIn('interface', status)
        self.assertIn('ip_info', status)

    @patch('subprocess.run')
    def test_get_status_not_running(self, mock_run):
        """Test getting WireGuard status when service is not running"""
        # Mock WireGuard service not running
        mock_run.return_value.stdout = ''
        mock_run.return_value.returncode = 0

        status = self.wg_manager.get_status()

        self.assertFalse(status['running'])

    @patch('subprocess.run')
    def test_test_connectivity(self, mock_run):
        """Test connectivity testing"""
        # Mock successful ping
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'PING 192.168.1.100'
        mock_run.return_value.stderr = ''

        result = self.wg_manager.test_connectivity('192.168.1.100')

        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertTrue(result['ping_success'])
        self.assertIn('192.168.1.100', result['ping_output'])

    @patch('subprocess.run')
    def test_test_connectivity_failure(self, mock_run):
        """Test connectivity testing with failure"""
        # Mock failed ping
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        mock_run.return_value.stderr = 'No route to host'

        result = self.wg_manager.test_connectivity('192.168.1.100')

        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertFalse(result['ping_success'])
        self.assertIn('No route to host', result['ping_error'])

    def test_update_peer_ip(self):
        """Test updating peer IP address"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '192.168.1.100')

        # Update peer IP
        success = self.wg_manager.update_peer_ip(peer_keys['public_key'], '192.168.1.200')
        self.assertTrue(success)

        # Check if IP was updated in config
        self.assertIn('192.168.1.200', self._read_wg0_config())

    def test_get_peer_config(self):
        """Test generating peer configuration"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        keys = self.wg_manager.get_keys()

        config = self.wg_manager.get_peer_config(
            'testpeer', '192.168.1.100', peer_keys['private_key']
        )

        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('[Peer]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 192.168.1.100/32', config)
        self.assertIn('DNS = 172.20.0.2', config)
        # The peer config must point back at this server's public key
        self.assertIn(keys['public_key'], config)
        self.assertIn('AllowedIPs = 172.20.0.0/16', config)

    def test_multiple_peers(self):
        """Test managing multiple peers"""
        # Add first peer
        peer1_keys = self.wg_manager.generate_peer_keys('peer1')
        success1 = self.wg_manager.add_peer('peer1', peer1_keys['public_key'], '192.168.1.100')
        self.assertTrue(success1)

        # Add second peer
        peer2_keys = self.wg_manager.generate_peer_keys('peer2')
        success2 = self.wg_manager.add_peer('peer2', peer2_keys['public_key'], '192.168.1.101')
        self.assertTrue(success2)

        # Get peers
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 2)

        # Remove first peer
        success3 = self.wg_manager.remove_peer(peer1_keys['public_key'])
        self.assertTrue(success3)

        # Check remaining peers
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 1)
        self.assertEqual(peers[0]['public_key'], peer2_keys['public_key'])

    def test_config_file_parsing(self):
        """Test parsing WireGuard configuration file"""
        # Create a test config file; built line by line so the written text
        # carries no source-code indentation.
        test_config = "\n".join([
            "[Interface]",
            "PrivateKey = test_private_key",
            "Address = 172.20.0.1/16",
            "ListenPort = 51820",
            "",
            "[Peer]",
            "PublicKey = peer1_public_key",
            "AllowedIPs = 172.20.0.0/16",
            "PersistentKeepalive = 25",
            "",
            "[Peer]",
            "PublicKey = peer2_public_key",
            "AllowedIPs = 172.20.1.0/24",
            "PersistentKeepalive = 30",
            "",
        ])
        with open(self._wg0_config_path(), 'w') as f:
            f.write(test_config)

        peers = self.wg_manager.get_peers()

        self.assertEqual(len(peers), 2)
        self.assertEqual(peers[0]['public_key'], 'peer1_public_key')
        self.assertEqual(peers[0]['allowed_ips'], '172.20.0.0/16')
        self.assertEqual(peers[0]['persistent_keepalive'], 25)
        self.assertEqual(peers[1]['public_key'], 'peer2_public_key')
        self.assertEqual(peers[1]['allowed_ips'], '172.20.1.0/24')
        self.assertEqual(peers[1]['persistent_keepalive'], 30)

    def test_error_handling(self):
        """Test error handling in WireGuard operations.

        Each scenario runs in its own subTest so one failing expectation does
        not hide the outcome of the others. The scenarios still run in order
        against the same manager instance.
        """
        with self.subTest('add peer with invalid public key'):
            # Should still return True as it writes to config file
            success = self.wg_manager.add_peer('testpeer', 'invalid_key', '192.168.1.100')
            self.assertTrue(success)

        with self.subTest('remove non-existent peer'):
            success = self.wg_manager.remove_peer('non_existent_key')
            self.assertTrue(success)

        with self.subTest('update IP of non-existent peer'):
            success = self.wg_manager.update_peer_ip('non_existent_key', '192.168.1.200')
            self.assertFalse(success)


if __name__ == '__main__':
    unittest.main()