Files
pic/tests/test_wireguard_manager.py
T
roof 28a193e430 Fix ensure_postup_dnat to strip-and-replace all DNAT rules idempotently
_get_dnat_container_ips() used a concatenating docker inspect format that
produced "invalid IP" when containers had multiple network attachments.
The old ensure_postup_dnat appended rather than replaced, so each update
call added another broken duplicate set of rules, causing iptables to fail
on startup and tearing down wg0 entirely.

Fix _get_dnat_container_ips to use a space separator in the format string
and validate each token as a real IP before accepting it.

Rewrite ensure_postup_dnat with an _is_dnat_rule() helper: after splitting
the rule string on semicolons, it strips every managed DNAT/FORWARD rule
(any IP, port 53/80) and appends a single correct set, so the result is
fully idempotent regardless of prior state.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-04 06:54:20 -04:00

873 lines
37 KiB
Python

#!/usr/bin/env python3
"""
Unit tests for WireGuardManager class
"""
import sys
from pathlib import Path
# Add api directory to path
api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))
import unittest
import tempfile
import os
import json
import shutil
import base64
from unittest.mock import patch, MagicMock
from datetime import datetime
# Add parent directory to path for imports
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from wireguard_manager import WireGuardManager
class TestWireGuardManager(unittest.TestCase):
    """Test cases for WireGuardManager class"""

    def setUp(self):
        """Create an isolated data/config tree and a manager with _syncconf stubbed."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal immediately: if WireGuardManager() raises below,
        # tearDown would never run and the temp tree would leak.
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)
        # Stub _syncconf for every test. Cleanups run LIFO, so the patch is
        # stopped before the temp tree is removed.
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)
        # Create WireGuardManager instance
        self.wg_manager = WireGuardManager(self.data_dir, self.config_dir)

    def test_initialization(self):
        """Test WireGuardManager initialization"""
        self.assertEqual(self.wg_manager.data_dir, self.data_dir)
        self.assertEqual(self.wg_manager.config_dir, self.config_dir)
        self.assertTrue(os.path.exists(self.wg_manager.wireguard_dir))
        self.assertTrue(os.path.exists(self.wg_manager.keys_dir))

    def test_key_generation(self):
        """Test WireGuard key generation"""
        # Check if keys were generated
        private_key_file = os.path.join(self.wg_manager.keys_dir, 'private.key')
        public_key_file = os.path.join(self.wg_manager.keys_dir, 'public.key')
        self.assertTrue(os.path.exists(private_key_file))
        self.assertTrue(os.path.exists(public_key_file))
        # Check key content
        with open(private_key_file, 'rb') as f:
            private_key = f.read()
        self.assertIsInstance(private_key, bytes)
        self.assertGreater(len(private_key), 0)
        with open(public_key_file, 'rb') as f:
            public_key = f.read()
        self.assertIsInstance(public_key, bytes)
        self.assertGreater(len(public_key), 0)

    def test_get_keys(self):
        """Test getting WireGuard keys"""
        keys = self.wg_manager.get_keys()
        self.assertIn('private_key', keys)
        self.assertIn('public_key', keys)
        self.assertIsInstance(keys['private_key'], str)
        self.assertIsInstance(keys['public_key'], str)
        self.assertGreater(len(keys['private_key']), 0)
        self.assertGreater(len(keys['public_key']), 0)

    def test_generate_peer_keys(self):
        """Test generating keys for a peer"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.assertIn('private_key', peer_keys)
        self.assertIn('public_key', peer_keys)
        self.assertIsInstance(peer_keys['private_key'], str)
        self.assertIsInstance(peer_keys['public_key'], str)
        # Check if peer keys were saved
        peer_keys_dir = os.path.join(self.wg_manager.keys_dir, 'peers')
        peer_private_file = os.path.join(peer_keys_dir, 'testpeer_private.key')
        peer_public_file = os.path.join(peer_keys_dir, 'testpeer_public.key')
        self.assertTrue(os.path.exists(peer_private_file))
        self.assertTrue(os.path.exists(peer_public_file))

    def test_generate_config(self):
        """Test WireGuard configuration generation"""
        config = self.wg_manager.generate_config('wg0', 51820)
        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 10.0.0.1/24', config)
        self.assertIn('ListenPort = 51820', config)
        self.assertIn('PostUp', config)
        self.assertIn('PostDown', config)

    def test_add_peer(self):
        """Test adding a peer — server-side AllowedIPs must be /32."""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        success = self.wg_manager.add_peer(
            'testpeer',
            peer_keys['public_key'],
            '',
            '10.0.0.2/32',
            25
        )
        self.assertTrue(success)
        config_file = self.wg_manager._config_file()
        self.assertTrue(os.path.exists(config_file))
        with open(config_file, 'r') as f:
            config = f.read()
        self.assertIn('[Peer]', config)
        self.assertIn(peer_keys['public_key'], config)
        self.assertIn('AllowedIPs = 10.0.0.2/32', config)
        self.assertIn('PersistentKeepalive = 25', config)

    def test_remove_peer(self):
        """Test removing a peer from WireGuard configuration"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')
        # Remove the peer
        success = self.wg_manager.remove_peer(peer_keys['public_key'])
        self.assertTrue(success)
        # Check if peer was removed
        config_file = self.wg_manager._config_file()
        with open(config_file, 'r') as f:
            config = f.read()
        self.assertNotIn(peer_keys['public_key'], config)

    def test_get_peers(self):
        """Test getting list of configured peers"""
        # Add a peer first
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')
        peers = self.wg_manager.get_peers()
        self.assertIsInstance(peers, list)
        self.assertEqual(len(peers), 1)
        self.assertIn('public_key', peers[0])
        self.assertIn('allowed_ips', peers[0])
        self.assertIn('persistent_keepalive', peers[0])
        self.assertEqual(peers[0]['public_key'], peer_keys['public_key'])

    @patch('subprocess.run')
    def test_get_status(self, mock_run):
        """Test getting WireGuard status"""
        # Mock WireGuard service running
        mock_run.return_value.stdout = 'cell-wireguard\n'
        mock_run.return_value.returncode = 0
        status = self.wg_manager.get_status()
        self.assertTrue(status['running'])
        self.assertIn('interface', status)
        self.assertIn('ip_info', status)

    @patch('subprocess.run')
    def test_get_status_not_running(self, mock_run):
        """Test getting WireGuard status when service is not running"""
        # Mock WireGuard service not running
        mock_run.return_value.stdout = ''
        mock_run.return_value.returncode = 0
        status = self.wg_manager.get_status()
        self.assertFalse(status['running'])

    @patch('subprocess.run')
    def test_test_connectivity(self, mock_run):
        """Test connectivity testing"""
        # Mock successful ping
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'PING 192.168.1.100'
        mock_run.return_value.stderr = ''
        result = self.wg_manager.test_connectivity('192.168.1.100')
        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertTrue(result['ping_success'])
        self.assertIn('192.168.1.100', result['ping_output'])

    @patch('subprocess.run')
    def test_test_connectivity_failure(self, mock_run):
        """Test connectivity testing with failure"""
        # Mock failed ping
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        mock_run.return_value.stderr = 'No route to host'
        result = self.wg_manager.test_connectivity('192.168.1.100')
        self.assertEqual(result['peer_ip'], '192.168.1.100')
        self.assertFalse(result['ping_success'])
        self.assertIn('No route to host', result['ping_error'])

    def test_update_peer_ip(self):
        """Test updating peer IP address"""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        self.wg_manager.add_peer('testpeer', peer_keys['public_key'], '', '10.0.0.2/32')
        success = self.wg_manager.update_peer_ip(peer_keys['public_key'], '10.0.0.9/32')
        self.assertTrue(success)
        with open(self.wg_manager._config_file(), 'r') as f:
            config = f.read()
        self.assertIn('10.0.0.9/32', config)

    def test_update_peer_ip_accepts_comma_separated_cidrs(self):
        """update_peer_ip accepts comma-separated CIDRs for exit-relay AllowedIPs."""
        peer_keys = self.wg_manager.generate_peer_keys('exitpeer')
        # Use add_cell_peer — cell peers have subnet AllowedIPs, not /32
        self.wg_manager.add_cell_peer(
            'exitpeer', peer_keys['public_key'], '', '10.0.1.0/24')
        success = self.wg_manager.update_peer_ip(
            peer_keys['public_key'], '10.0.1.0/24, 0.0.0.0/0')
        self.assertTrue(success, 'Should accept comma-separated CIDRs')
        with open(self.wg_manager._config_file(), 'r') as f:
            config = f.read()
        self.assertIn('10.0.1.0/24, 0.0.0.0/0', config)

    def test_update_peer_ip_rejects_newlines(self):
        """update_peer_ip rejects strings with newlines (injection guard)."""
        peer_keys = self.wg_manager.generate_peer_keys('badpeer')
        self.wg_manager.add_peer('badpeer', peer_keys['public_key'], '', '10.0.0.2/32')
        success = self.wg_manager.update_peer_ip(peer_keys['public_key'], '10.0.0.9/32\nPostUp=evil')
        self.assertFalse(success)

    def test_get_peer_config(self):
        """Test generating peer client configuration."""
        peer_keys = self.wg_manager.generate_peer_keys('testpeer')
        keys = self.wg_manager.get_keys()
        config = self.wg_manager.get_peer_config('testpeer', '10.0.0.2', peer_keys['private_key'])
        self.assertIsInstance(config, str)
        self.assertIn('[Interface]', config)
        self.assertIn('[Peer]', config)
        self.assertIn('PrivateKey', config)
        self.assertIn('Address = 10.0.0.2/32', config)
        self.assertIn('DNS = 10.0.0.1', config)
        self.assertIn(keys['public_key'], config)
        self.assertIn('AllowedIPs', config)

    def test_multiple_peers(self):
        """Test managing multiple peers"""
        peer1_keys = self.wg_manager.generate_peer_keys('peer1')
        success1 = self.wg_manager.add_peer('peer1', peer1_keys['public_key'], '', '10.0.0.2/32')
        self.assertTrue(success1)
        peer2_keys = self.wg_manager.generate_peer_keys('peer2')
        success2 = self.wg_manager.add_peer('peer2', peer2_keys['public_key'], '', '10.0.0.3/32')
        self.assertTrue(success2)
        # Get peers
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 2)
        # Remove first peer
        success3 = self.wg_manager.remove_peer(peer1_keys['public_key'])
        self.assertTrue(success3)
        # Check remaining peers
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 1)
        self.assertEqual(peers[0]['public_key'], peer2_keys['public_key'])

    def test_config_file_parsing(self):
        """Test parsing WireGuard configuration file"""
        # Write to the path the manager itself reports via _config_file()
        # (as the other config-writing helpers in this module do) rather than
        # a hard-coded location under wireguard_dir.
        config_file = self.wg_manager._config_file()
        os.makedirs(os.path.dirname(config_file), exist_ok=True)
        test_config = (
            '[Interface]\n'
            'PrivateKey = test_private_key\n'
            'Address = 172.20.0.1/16\n'
            'ListenPort = 51820\n'
            '[Peer]\n'
            'PublicKey = peer1_public_key\n'
            'AllowedIPs = 172.20.0.0/16\n'
            'PersistentKeepalive = 25\n'
            '[Peer]\n'
            'PublicKey = peer2_public_key\n'
            'AllowedIPs = 172.20.1.0/24\n'
            'PersistentKeepalive = 30\n'
        )
        with open(config_file, 'w') as f:
            f.write(test_config)
        peers = self.wg_manager.get_peers()
        self.assertEqual(len(peers), 2)
        self.assertEqual(peers[0]['public_key'], 'peer1_public_key')
        self.assertEqual(peers[0]['allowed_ips'], '172.20.0.0/16')
        self.assertEqual(peers[0]['persistent_keepalive'], 25)
        self.assertEqual(peers[1]['public_key'], 'peer2_public_key')
        self.assertEqual(peers[1]['allowed_ips'], '172.20.1.0/24')
        self.assertEqual(peers[1]['persistent_keepalive'], 30)

    def test_error_handling(self):
        """Test error handling in WireGuard operations."""
        # Wide CIDR rejected — server-side AllowedIPs must be /32
        success = self.wg_manager.add_peer('testpeer', 'invalid_key', '', '172.20.0.0/16')
        self.assertFalse(success, "Wide CIDR must be rejected")
        # Valid /32 with any key string is accepted (key format not validated at this layer)
        success = self.wg_manager.add_peer('testpeer', 'YW55X2tleV9zdHJpbmdfZm9yX3Rlc3RzX3dnMTIzISE=', '', '10.0.0.2/32')
        self.assertTrue(success)
        # Removing non-existent peer is a no-op, not an error
        success = self.wg_manager.remove_peer('non_existent_key')
        self.assertTrue(success)
        # Updating IP for peer not in config returns False
        success = self.wg_manager.update_peer_ip('non_existent_key', '10.0.0.9/32')
        self.assertFalse(success)
class TestWireGuardCellPeer(unittest.TestCase):
    """Test add_cell_peer allows subnet CIDRs for site-to-site connections."""

    # 44-character base64 strings used as peer public keys.
    _CELL_PUBKEY = 'cmVtb3RlcHVia2V5X2Zvcl90ZXN0c193Z3Rlc3QxMiE='
    _ALICE_PUBKEY = 'YWxpY2VwdWJrZXlfZm9yX3Rlc3RzX3dndGVzdDEyMyE='

    def setUp(self):
        """Create an isolated data/config tree and a manager with _syncconf stubbed."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal immediately so the temp tree is cleaned up even if
        # WireGuardManager() raises below (tearDown would not run then).
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)
        self.wg = WireGuardManager(self.data_dir, self.config_dir)

    def _add_remote(self, vpn_subnet='10.1.0.0/24'):
        """Add the 'remote' cell peer at 5.6.7.8:51821 and return add_cell_peer's result."""
        return self.wg.add_cell_peer('remote', self._CELL_PUBKEY, '5.6.7.8:51821', vpn_subnet)

    def test_add_cell_peer_allows_subnet_cidr(self):
        ok = self._add_remote()
        self.assertTrue(ok)
        content = self.wg._read_config()
        self.assertIn('10.1.0.0/24', content)

    def test_add_cell_peer_writes_full_endpoint(self):
        # Assert the add succeeded so a rejection is reported as such,
        # not as a missing-substring failure below.
        self.assertTrue(self._add_remote())
        content = self.wg._read_config()
        self.assertIn('Endpoint = 5.6.7.8:51821', content)

    def test_add_cell_peer_comment_has_cell_prefix(self):
        self.assertTrue(self._add_remote())
        content = self.wg._read_config()
        self.assertIn('# cell:remote', content)

    def test_add_cell_peer_invalid_cidr_returns_false(self):
        ok = self._add_remote('not-a-cidr')
        self.assertFalse(ok)

    def test_add_cell_peer_can_coexist_with_regular_peers(self):
        self.assertTrue(self.wg.add_peer('alice', self._ALICE_PUBKEY, '', '10.0.0.2/32'))
        self.assertTrue(self._add_remote())
        content = self.wg._read_config()
        self.assertIn(self._ALICE_PUBKEY, content)
        self.assertIn(self._CELL_PUBKEY, content)
class TestWireGuardConfigReads(unittest.TestCase):
    """Test that port/address/network are read from wg0.conf, not hardcoded."""

    def setUp(self):
        """Create an isolated data/config tree and a manager with _syncconf stubbed."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal immediately so the temp tree is cleaned up even if
        # WireGuardManager() raises below (tearDown would not run then).
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)
        self.wg = WireGuardManager(self.data_dir, self.config_dir)

    def _write_wg_conf(self, port=51820, address='10.0.0.1/24', extra=''):
        """Write a minimal [Interface] section to the manager's _config_file()."""
        conf = (
            f'[Interface]\n'
            f'PrivateKey = dummykey\n'
            f'Address = {address}\n'
            f'ListenPort = {port}\n'
            f'{extra}'
        )
        cf = self.wg._config_file()
        os.makedirs(os.path.dirname(cf), exist_ok=True)
        with open(cf, 'w') as f:
            f.write(conf)

    def test_get_configured_port_reads_from_wg_conf(self):
        self._write_wg_conf(port=54321)
        self.assertEqual(self.wg._get_configured_port(), 54321)

    def test_get_configured_port_fallback_when_no_file(self):
        # No wg0.conf exists — fall back to DEFAULT_PORT
        self.assertEqual(self.wg._get_configured_port(), 51820)

    def test_get_configured_address_reads_from_wg_conf(self):
        self._write_wg_conf(address='10.1.0.1/24')
        self.assertEqual(self.wg._get_configured_address(), '10.1.0.1/24')

    def test_get_configured_network_derives_from_address(self):
        self._write_wg_conf(address='10.1.0.1/24')
        self.assertEqual(self.wg._get_configured_network(), '10.1.0.0/24')

    def test_get_split_tunnel_ips_uses_configured_network(self):
        self._write_wg_conf(address='10.1.0.1/24')
        split = self.wg.get_split_tunnel_ips()
        self.assertIn('10.1.0.0/24', split)
        # 172.20.0.0/16 is intentionally excluded — services now use WG server IP via DNAT
        self.assertNotIn('172.20.0.0/16', split)
        self.assertNotIn('10.0.0.0/24', split)

    def test_get_server_config_uses_configured_port(self):
        self._write_wg_conf(port=54321)
        with patch.object(self.wg, 'get_external_ip', return_value='1.2.3.4'):
            cfg = self.wg.get_server_config()
        self.assertEqual(cfg['port'], 54321)
        self.assertIn(':54321', cfg['endpoint'])

    def test_get_server_config_includes_dns_and_split_tunnel(self):
        self._write_wg_conf(address='10.2.0.1/24')
        with patch.object(self.wg, 'get_external_ip', return_value='1.2.3.4'):
            cfg = self.wg.get_server_config()
        self.assertIn('dns_ip', cfg)
        # dns_ip must be the WG server IP, not the container IP
        self.assertEqual(cfg['dns_ip'], '10.2.0.1')
        self.assertIn('split_tunnel_ips', cfg)
        self.assertIn('10.2.0.0/24', cfg['split_tunnel_ips'])

    def test_get_peer_config_uses_configured_port_in_endpoint(self):
        self._write_wg_conf(port=54321)
        result = self.wg.get_peer_config(
            peer_name='alice',
            peer_ip='10.0.0.2',
            peer_private_key='privkeyalice=',
            server_endpoint='5.6.7.8',
        )
        self.assertIn(':54321', result)
        self.assertNotIn(':51820', result)

    def test_add_peer_uses_configured_port_in_endpoint(self):
        self._write_wg_conf(port=54321)
        # Check the add itself succeeded so a rejection is not misreported
        # as a missing Endpoint line.
        ok = self.wg.add_peer('alice', 'cHVia2V5YWxpY2VfZm9yX3Rlc3RzX3dpcmVndWFyZCE=', '5.6.7.8', '10.0.0.2/32')
        self.assertTrue(ok)
        content = self.wg._read_config()
        self.assertIn('Endpoint = 5.6.7.8:54321', content)
        self.assertNotIn(':51820', content)
class TestWireGuardSysctlAndPortCheck(unittest.TestCase):
    """Tests for sysctl safety, port check, and peer status parsing."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.wg = WireGuardManager(self.test_dir, self.test_dir)

    def _read_conf(self):
        """Return the current text of the manager's wg0.conf."""
        with open(self.wg._config_file()) as f:
            return f.read()

    # ── generate_config sysctl safety ────────────────────────────────────────
    def test_generate_config_postup_has_nonfatal_sysctl(self):
        cfg = self.wg.generate_config()
        self.assertIn('sysctl -q net.ipv4.conf.all.rp_filter=0 || true', cfg)

    def test_generate_config_postdown_has_nonfatal_sysctl(self):
        cfg = self.wg.generate_config()
        self.assertIn('sysctl -q net.ipv4.conf.all.rp_filter=1 || true', cfg)

    def test_generate_config_has_masquerade(self):
        cfg = self.wg.generate_config()
        self.assertIn('MASQUERADE', cfg)

    def test_generate_config_has_forward_drop_rule(self):
        cfg = self.wg.generate_config()
        self.assertIn('FORWARD -i %i -j DROP', cfg)
        self.assertNotIn('FORWARD -i %i -j ACCEPT', cfg)

    def test_generate_config_includes_dns_dnat_in_postup(self):
        cfg = self.wg.generate_config()
        self.assertIn('--dport 53 -j DNAT', cfg)
        self.assertIn('--dport 80 -j DNAT', cfg)

    def test_generate_config_postdown_removes_dnat(self):
        cfg = self.wg.generate_config()
        postdown_line = [l for l in cfg.splitlines() if l.startswith('PostDown')][0]
        self.assertIn('--dport 53 -j DNAT', postdown_line)
        self.assertIn('--dport 80 -j DNAT', postdown_line)

    # ── ensure_postup_dnat ────────────────────────────────────────────────────
    def _write_wg_conf_postup(self, address='10.0.0.1/24', extra_postup=''):
        """Write a wg0.conf whose PostUp/PostDown lack the managed DNAT rules.

        The file goes to self.wg._config_file(), the same path the tests read
        back, instead of a hard-coded location. Returns that path.
        """
        conf_path = self.wg._config_file()
        os.makedirs(os.path.dirname(conf_path), exist_ok=True)
        postup = (
            'iptables -A FORWARD -i %i -j DROP; '
            'iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE'
        )
        if extra_postup:
            postup += f'; {extra_postup}'
        postdown = (
            'iptables -D FORWARD -i %i -j DROP; '
            'iptables -t nat -D POSTROUTING -o eth0 -j MASQUERADE'
        )
        content = (
            '[Interface]\n'
            f'Address = {address}\n'
            f'PostUp = {postup}\n'
            f'PostDown = {postdown}\n'
        )
        with open(conf_path, 'w') as f:
            f.write(content)
        return conf_path

    @patch('subprocess.run')
    def test_ensure_postup_dnat_adds_rules_when_missing(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = '172.20.0.3'
        self._write_wg_conf_postup()
        changed = self.wg.ensure_postup_dnat()
        self.assertTrue(changed)
        content = self._read_conf()
        self.assertIn('--dport 53 -j DNAT', content)
        self.assertIn('--dport 80 -j DNAT', content)

    @patch('subprocess.run')
    def test_ensure_postup_dnat_idempotent_when_rules_present(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = '172.20.0.3'
        self._write_wg_conf_postup()
        # First call: writes the DNAT rule set
        first = self.wg.ensure_postup_dnat()
        self.assertTrue(first)
        after_first = self._read_conf()
        # Second call: rules already correct, so nothing changes. Also compare
        # the file text, because the pre-fix code appended a duplicate rule set
        # on every call.
        second = self.wg.ensure_postup_dnat()
        self.assertFalse(second)
        self.assertEqual(self._read_conf(), after_first)

    @patch('subprocess.run')
    def test_ensure_postup_dnat_replaces_rules_when_container_ip_changes(self, mock_run):
        # Stale rules for an old container IP must be stripped, not kept next
        # to the fresh set. The stale set is produced by ensure_postup_dnat
        # itself, so the test does not depend on the exact rule syntax.
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = '172.20.0.99'
        self._write_wg_conf_postup()
        self.assertTrue(self.wg.ensure_postup_dnat())
        before = self._read_conf()
        mock_run.return_value.stdout = '172.20.0.3'
        self.assertTrue(self.wg.ensure_postup_dnat())
        after = self._read_conf()
        self.assertNotIn('172.20.0.99', after)
        self.assertEqual(after.count('-j DNAT'), before.count('-j DNAT'))

    @patch('subprocess.run')
    def test_ensure_postup_dnat_rejects_concatenated_inspect_output(self, mock_run):
        # Regression: the old docker-inspect format glued the IPs of
        # multi-network containers together; that token is not an IP and
        # must never be written into a rule.
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = '172.20.0.3172.21.0.4'
        self._write_wg_conf_postup()
        self.wg.ensure_postup_dnat()
        self.assertNotIn('172.20.0.3172.21.0.4', self._read_conf())

    @patch('subprocess.run')
    def test_ensure_postup_dnat_handles_multiple_network_attachments(self, mock_run):
        # Space-separated output (one IP per attachment) must still produce
        # rules, and must not embed the raw multi-token string.
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = '172.20.0.3 172.21.0.4\n'
        self._write_wg_conf_postup()
        self.assertTrue(self.wg.ensure_postup_dnat())
        content = self._read_conf()
        self.assertIn('--dport 53 -j DNAT', content)
        self.assertNotIn('172.20.0.3 172.21.0.4', content)

    def test_ensure_postup_dnat_returns_false_when_no_conf(self):
        changed = self.wg.ensure_postup_dnat()
        self.assertFalse(changed)

    # ── check_port_open ───────────────────────────────────────────────────────
    @patch('subprocess.run')
    def test_check_port_open_when_wg_interface_up(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 51820\n'
        self.assertTrue(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_false_when_interface_down(self, mock_run):
        # wg show fails (no device), fallback wg show dump also fails
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        self.assertFalse(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_fallback_to_recent_handshake(self, mock_run):
        # First call (wg show wg0) exits 0, but its output has no
        # "listening port" line, so the port cannot be confirmed from it.
        # Second call (wg show wg0 dump) returns a peer with a recent handshake.
        import time as _time
        now = int(_time.time())
        dump_line = f'pubkey\t(none)\t1.2.3.4:51820\t0.0.0.0/0\t{now - 10}\t1000\t2000\t25\n'

        def side_effect(*args, **kwargs):
            cmd = args[0]
            m = MagicMock()
            if 'dump' in cmd:
                m.returncode = 0
                m.stdout = dump_line
            else:
                m.returncode = 0
                m.stdout = 'interface: wg0\n'  # no "listening port" text
            return m

        mock_run.side_effect = side_effect
        # "listening port" not in stdout for first call → falls through to dump
        # dump has recent handshake → returns True
        result = self.wg.check_port_open()
        self.assertTrue(result)

    @patch('subprocess.run')
    def test_check_port_open_wrong_port_returns_false(self, mock_run):
        # wg0 is up but listening on 51820 while wg0.conf says 51821 — must return False
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 51820\n'
        # Write wg0.conf (at the path _get_configured_port() reads) with a
        # different port so _get_configured_port() returns 51821
        cfg_path = self.wg._config_file()
        os.makedirs(os.path.dirname(cfg_path), exist_ok=True)
        with open(cfg_path, 'w') as f:
            f.write('[Interface]\nListenPort = 51821\nPrivateKey = abc\n')
        self.assertFalse(self.wg.check_port_open())

    @patch('subprocess.run')
    def test_check_port_open_explicit_port_matches(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 12345\n'
        self.assertTrue(self.wg.check_port_open(port=12345))

    @patch('subprocess.run')
    def test_check_port_open_explicit_port_mismatch(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = 'interface: wg0\n listening port: 51820\n'
        self.assertFalse(self.wg.check_port_open(port=51821))

    # ── get_peer_status ───────────────────────────────────────────────────────
    @patch('subprocess.run')
    def test_get_peer_status_online_with_recent_handshake(self, mock_run):
        import time as _time
        now = int(_time.time())
        pub = 'AAABBBCCC='
        dump = (
            f'privkey\tserverpub\t51820\toff\n'  # interface line (4 fields)
            f'{pub}\t(none)\t1.2.3.4:12345\t10.0.0.2/32\t{now-30}\t500\t1000\t25\n'
        )
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump
        st = self.wg.get_peer_status(pub)
        self.assertTrue(st['online'])
        self.assertIsNotNone(st['last_handshake'])
        self.assertLessEqual(st['last_handshake_seconds_ago'], 35)

    @patch('subprocess.run')
    def test_get_peer_status_offline_with_old_handshake(self, mock_run):
        import time as _time
        now = int(_time.time())
        pub = 'AAABBBCCC='
        dump = f'{pub}\t(none)\t(none)\t10.0.0.2/32\t{now - 300}\t0\t0\t25\n'
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump
        st = self.wg.get_peer_status(pub)
        self.assertFalse(st['online'])

    @patch('subprocess.run')
    def test_get_peer_status_not_found_returns_none_online(self, mock_run):
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = ''
        st = self.wg.get_peer_status('NOTEXIST=')
        self.assertIsNone(st['online'])

    @patch('subprocess.run')
    def test_get_peer_status_no_handshake_yet(self, mock_run):
        pub = 'AAABBBCCC='
        dump = f'{pub}\t(none)\t(none)\t10.0.0.2/32\t0\t0\t0\t25\n'
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump
        st = self.wg.get_peer_status(pub)
        self.assertFalse(st['online'])
        self.assertIsNone(st['last_handshake'])

    # ── get_all_peer_statuses ─────────────────────────────────────────────────
    @patch('subprocess.run')
    def test_get_all_peer_statuses_parses_multiple_peers(self, mock_run):
        import time as _time
        now = int(_time.time())
        pub1 = 'PUB1KEY='
        pub2 = 'PUB2KEY='
        dump = (
            f'privkey\tserverpub\t51820\toff\n'
            f'{pub1}\t(none)\t1.1.1.1:1000\t10.0.0.2/32\t{now-20}\t100\t200\t25\n'
            f'{pub2}\t(none)\t(none)\t10.0.0.3/32\t{now-200}\t0\t0\t25\n'
        )
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump
        statuses = self.wg.get_all_peer_statuses()
        self.assertIn(pub1, statuses)
        self.assertIn(pub2, statuses)
        self.assertTrue(statuses[pub1]['online'])
        self.assertFalse(statuses[pub2]['online'])

    @patch('subprocess.run')
    def test_get_all_peer_statuses_empty_when_interface_down(self, mock_run):
        mock_run.return_value.returncode = 1
        mock_run.return_value.stdout = ''
        statuses = self.wg.get_all_peer_statuses()
        self.assertEqual(statuses, {})

    @patch('subprocess.run')
    def test_get_all_peer_statuses_skips_interface_line(self, mock_run):
        # Interface line has only 4 tab-separated fields — must not appear as a peer
        dump = 'privkey\tserverpub\t51820\toff\n'
        mock_run.return_value.returncode = 0
        mock_run.return_value.stdout = dump
        statuses = self.wg.get_all_peer_statuses()
        self.assertEqual(statuses, {})
class TestAddCellPeerSubnetOverlap(unittest.TestCase):
    """Verify that add_cell_peer rejects a vpn_subnet that overlaps the local WG network."""

    # Public key is 44 chars ending in '=' — required by validation in add_cell_peer
    _CELL_PUBKEY = 'cmVtb3RlcHVia2V5X2Zvcl90ZXN0c193Z3Rlc3QxMiE='

    def setUp(self):
        """Create an isolated tree, a manager, and a wg0.conf with Address 10.0.0.1/24."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal immediately so the temp tree is cleaned up even if
        # a later setUp step raises (tearDown would not run then).
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)
        self.wg = WireGuardManager(self.data_dir, self.config_dir)
        # Write a known wg0.conf so _get_configured_network() returns 10.0.0.0/24
        self._write_wg_conf(address='10.0.0.1/24')

    def _write_wg_conf(self, address='10.0.0.1/24', port=51820):
        """Write a minimal [Interface] section to the manager's _config_file()."""
        conf = (
            f'[Interface]\n'
            f'PrivateKey = dummykey\n'
            f'Address = {address}\n'
            f'ListenPort = {port}\n'
        )
        cf = self.wg._config_file()
        os.makedirs(os.path.dirname(cf), exist_ok=True)
        with open(cf, 'w') as f:
            f.write(conf)

    def test_add_cell_peer_overlapping_subnet_returns_false(self):
        """vpn_subnet that exactly matches the local WG network must be rejected."""
        # local is 10.0.0.0/24; remote is also 10.0.0.0/24 — clear overlap
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.0.0.0/24'
        )
        self.assertFalse(ok)

    def test_add_cell_peer_partially_overlapping_subnet_returns_false(self):
        """A remote subnet that contains the local network (e.g. /16 ⊃ /24) is rejected."""
        # 10.0.0.0/16 contains 10.0.0.0/24 → overlaps
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.0.0.0/16'
        )
        self.assertFalse(ok)

    def test_add_cell_peer_non_overlapping_subnet_accepted(self):
        """A remote subnet distinct from the local WG network must be accepted."""
        # local is 10.0.0.0/24; remote is 10.0.1.0/24 — no overlap
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.0.1.0/24'
        )
        self.assertTrue(ok)

    def test_add_cell_peer_no_overlap_different_class_a(self):
        """A completely different address space is accepted."""
        # local is 10.0.0.0/24; remote is 192.168.5.0/24 — no overlap
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '192.168.5.0/24'
        )
        self.assertTrue(ok)

    def test_add_cell_peer_overlap_check_uses_configured_network(self):
        """When wg0.conf says 172.16.0.1/12, overlapping that range is rejected."""
        self._write_wg_conf(address='172.16.0.1/12')
        ok = self.wg.add_cell_peer(
            'remote', self._CELL_PUBKEY, '5.6.7.8:51821', '172.16.0.0/12'
        )
        self.assertFalse(ok)
class TestCellRoutes(unittest.TestCase):
    """Tests for _ensure_cell_route and sync_cell_routes."""

    _CELL_PUBKEY = 'cmVtb3RlcHVia2V5X2Zvcl90ZXN0c193Z3Rlc3QxMiE='

    def setUp(self):
        """Create an isolated data/config tree and a manager with _syncconf stubbed."""
        self.test_dir = tempfile.mkdtemp()
        # Register removal immediately so the temp tree is cleaned up even if
        # WireGuardManager() raises below (tearDown would not run then).
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.data_dir = os.path.join(self.test_dir, 'data')
        self.config_dir = os.path.join(self.test_dir, 'config')
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.config_dir, exist_ok=True)
        patcher = patch.object(WireGuardManager, '_syncconf', return_value=None)
        self.mock_sync = patcher.start()
        self.addCleanup(patcher.stop)
        self.wg = WireGuardManager(self.data_dir, self.config_dir)

    def test_ensure_cell_route_noop_in_test_dir(self):
        """_ensure_cell_route must not call subprocess when config is in /tmp (test env)."""
        with patch('subprocess.run') as mock_run:
            self.wg._ensure_cell_route('10.1.0.0/24')
        mock_run.assert_not_called()

    def test_sync_cell_routes_noop_in_test_dir(self):
        """sync_cell_routes must not call subprocess when config is in /tmp (test env)."""
        with patch('subprocess.run') as mock_run:
            self.wg.sync_cell_routes()
        mock_run.assert_not_called()

    def test_ensure_cell_route_calls_ip_route_add(self):
        """Outside test dirs, _ensure_cell_route calls docker exec ip route add."""
        with patch.object(self.wg, '_config_file', return_value='/app/config/wireguard/wg_confs/wg0.conf'):
            with patch('subprocess.run') as mock_run:
                mock_run.return_value = MagicMock(returncode=0)
                self.wg._ensure_cell_route('10.1.0.0/24')
        mock_run.assert_called_once()
        cmd = mock_run.call_args[0][0]
        self.assertIn('ip', cmd)
        self.assertIn('route', cmd)
        self.assertIn('add', cmd)
        self.assertIn('10.1.0.0/24', cmd)
        self.assertIn('wg0', cmd)

    def test_sync_cell_routes_finds_cell_peers_in_config(self):
        """sync_cell_routes parses wg0.conf and adds routes for cell peers only."""
        conf = (
            '[Interface]\nPrivateKey = dummykey\nAddress = 10.0.0.1/24\nListenPort = 51820\n\n'
            '[Peer]\n# cell:remote\nPublicKey = cmVtb3RlcHVia2V5X2Zvcl90ZXN0c193Z3Rlc3QxMiE=\n'
            'AllowedIPs = 10.1.0.0/24\nPersistentKeepalive = 25\n\n'
            '[Peer]\n# alice\nPublicKey = YWxpY2VwdWJrZXlfZm9yX3Rlc3RzX3dndGVzdDEyMyE=\n'
            'AllowedIPs = 10.0.0.2/32\nPersistentKeepalive = 25\n'
        )
        with patch.object(self.wg, '_config_file', return_value='/app/config/wireguard/wg_confs/wg0.conf'):
            with patch.object(self.wg, '_read_config', return_value=conf):
                with patch('subprocess.run') as mock_run:
                    mock_run.return_value = MagicMock(returncode=0)
                    self.wg.sync_cell_routes()
        calls = [c[0][0] for c in mock_run.call_args_list]
        subnets = [c for c in calls if '10.1.0.0/24' in c]
        non_cell = [c for c in calls if '10.0.0.2/32' in c]
        self.assertTrue(len(subnets) >= 1, 'expected route add for cell peer subnet')
        self.assertEqual(len(non_cell), 0, 'should not add route for regular peer')

    def test_add_cell_peer_triggers_ensure_cell_route(self):
        """add_cell_peer calls _ensure_cell_route after writing config."""
        with patch.object(self.wg, '_ensure_cell_route') as mock_route:
            ok = self.wg.add_cell_peer('remote', self._CELL_PUBKEY, '5.6.7.8:51821', '10.1.0.0/24')
        # Check the add itself succeeded so a rejection is reported directly,
        # not as a missing route call.
        self.assertTrue(ok)
        mock_route.assert_called_once_with('10.1.0.0/24')
if __name__ == "__main__":
    # Executing this file directly hands control to the unittest CLI runner.
    unittest.main()