84d33aa88c
Added a path guard: if the config file resolves to /tmp/ or a pytest temp dir, _syncconf bails out immediately. Without this, tests calling add_peer/remove_peer with a temp-dir WireGuardManager would connect to the live cell-wireguard container and remove production peers. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
550 lines
23 KiB
Python
550 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
WireGuard Manager for Personal Internet Cell
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import base64
|
|
import socket
|
|
import subprocess
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
|
from base_service_manager import BaseServiceManager
|
|
|
|
try:
|
|
import requests as _requests
|
|
except ImportError:
|
|
_requests = None
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
SERVER_ADDRESS = '10.0.0.1/24'
|
|
SERVER_NETWORK = '10.0.0.0/24'
|
|
DEFAULT_PORT = 51820
|
|
|
|
def _resolve_peer_dns() -> str:
    """Return the IP address handed to VPN clients as their DNS server.

    The ``cell-dns`` hostname is resolved at call time. If the lookup raises
    ``OSError`` the hard-coded fallback ``172.20.0.3`` is returned instead.
    """
    candidates = ('cell-dns',)
    for candidate in candidates:
        try:
            resolved = socket.gethostbyname(candidate)
        except OSError:
            # Name not resolvable from here; try the next candidate (if any).
            continue
        return resolved
    return '172.20.0.3'
|
|
|
|
|
|
class WireGuardManager(BaseServiceManager):
|
|
"""Manages WireGuard VPN configuration and peers"""
|
|
|
|
def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config'):
    """Prepare the directory layout and make sure a server keypair exists.

    Args:
        data_dir: Root under which ``wireguard/keys`` and ``wireguard/peers``
            are created.
        config_dir: Root under which the ``wireguard`` config directory is
            created.
    """
    super().__init__('wireguard', data_dir, config_dir)

    wg_data_root = os.path.join(data_dir, 'wireguard')
    self.wireguard_dir = os.path.join(config_dir, 'wireguard')
    self.keys_dir = os.path.join(wg_data_root, 'keys')
    self.peers_dir = os.path.join(wg_data_root, 'peers')

    # Same creation order as the directories are declared above, with the
    # per-peer key directory created right after its parent.
    required_dirs = (
        self.wireguard_dir,
        self.keys_dir,
        os.path.join(self.keys_dir, 'peers'),
        self.peers_dir,
    )
    for directory in required_dirs:
        self.safe_makedirs(directory)

    self._ensure_server_keys()
|
|
|
|
# ── Key management ────────────────────────────────────────────────────────
|
|
|
|
@staticmethod
def _generate_keypair():
    """Create a fresh X25519 keypair and return ``(private_raw, public_raw)`` bytes."""
    private_key = X25519PrivateKey.generate()
    private_raw = private_key.private_bytes_raw()
    public_raw = private_key.public_key().public_bytes_raw()
    return private_raw, public_raw
|
|
|
|
def _ensure_server_keys(self):
|
|
priv_file = os.path.join(self.keys_dir, 'private.key')
|
|
pub_file = os.path.join(self.keys_dir, 'public.key')
|
|
if not os.path.exists(priv_file):
|
|
try:
|
|
priv_bytes, pub_bytes = self._generate_keypair()
|
|
with open(priv_file, 'wb') as f:
|
|
f.write(priv_bytes)
|
|
with open(pub_file, 'wb') as f:
|
|
f.write(pub_bytes)
|
|
except (PermissionError, OSError):
|
|
pass
|
|
|
|
def get_keys(self) -> Dict[str, str]:
|
|
"""Return server public/private keys as base64 strings."""
|
|
priv_file = os.path.join(self.keys_dir, 'private.key')
|
|
pub_file = os.path.join(self.keys_dir, 'public.key')
|
|
with open(priv_file, 'rb') as f:
|
|
priv = f.read()
|
|
with open(pub_file, 'rb') as f:
|
|
pub = f.read()
|
|
return {
|
|
'private_key': base64.b64encode(priv).decode(),
|
|
'public_key': base64.b64encode(pub).decode(),
|
|
}
|
|
|
|
def generate_peer_keys(self, peer_name: str) -> Dict[str, str]:
|
|
"""Generate a keypair for a peer, save to keys_dir/peers/, return as base64."""
|
|
priv_bytes, pub_bytes = self._generate_keypair()
|
|
priv_b64 = base64.b64encode(priv_bytes).decode()
|
|
pub_b64 = base64.b64encode(pub_bytes).decode()
|
|
|
|
peer_keys_dir = os.path.join(self.keys_dir, 'peers')
|
|
with open(os.path.join(peer_keys_dir, f'{peer_name}_private.key'), 'w') as f:
|
|
f.write(priv_b64)
|
|
with open(os.path.join(peer_keys_dir, f'{peer_name}_public.key'), 'w') as f:
|
|
f.write(pub_b64)
|
|
|
|
return {'private_key': priv_b64, 'public_key': pub_b64, 'peer_name': peer_name}
|
|
|
|
# ── Config generation ─────────────────────────────────────────────────────
|
|
|
|
def get_config(self, interface: str = 'wg0', port: int = DEFAULT_PORT):
    """Return the server config wrapped as ``{'config': <text>}``.

    Thin wrapper around ``generate_config`` that keeps the dict shape.
    """
    rendered = self.generate_config(interface, port)
    return {'config': rendered}
|
|
|
|
def generate_config(self, interface: str = 'wg0', port: int = DEFAULT_PORT) -> str:
|
|
"""Return a WireGuard [Interface] config string for the server."""
|
|
keys = self.get_keys()
|
|
ext_ip = self.get_external_ip() or ''
|
|
# Hairpin DNAT: redirect VPN clients targeting the server's public IP
|
|
# to 10.0.0.1 (the VPN interface), avoiding the Docker network loopback.
|
|
hairpin = (
|
|
f'iptables -t nat -A PREROUTING -i %i -d {ext_ip} -j DNAT --to-destination 10.0.0.1; '
|
|
if ext_ip else ''
|
|
)
|
|
hairpin_down = (
|
|
f'iptables -t nat -D PREROUTING -i %i -d {ext_ip} -j DNAT --to-destination 10.0.0.1; '
|
|
if ext_ip else ''
|
|
)
|
|
return (
|
|
f'[Interface]\n'
|
|
f'PrivateKey = {keys["private_key"]}\n'
|
|
f'Address = {SERVER_ADDRESS}\n'
|
|
f'ListenPort = {port}\n'
|
|
f'PostUp = iptables -A FORWARD -i %i -j ACCEPT; '
|
|
f'iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE; '
|
|
f'{hairpin}'
|
|
f'sysctl -q net.ipv4.conf.all.rp_filter=0\n'
|
|
f'PostDown = iptables -D FORWARD -i %i -j ACCEPT; '
|
|
f'iptables -t nat -D POSTROUTING -o eth0 -j MASQUERADE; '
|
|
f'{hairpin_down}'
|
|
f'sysctl -q net.ipv4.conf.all.rp_filter=1\n'
|
|
)
|
|
|
|
def _config_file(self) -> str:
|
|
# linuxserver/wireguard stores configs in wg_confs/
|
|
wg_confs = os.path.join(self.wireguard_dir, 'wg_confs')
|
|
if os.path.isdir(wg_confs):
|
|
return os.path.join(wg_confs, 'wg0.conf')
|
|
return os.path.join(self.wireguard_dir, 'wg0.conf')
|
|
|
|
def _read_config(self) -> str:
|
|
cf = self._config_file()
|
|
if os.path.exists(cf):
|
|
with open(cf, 'r') as f:
|
|
return f.read()
|
|
return self.generate_config()
|
|
|
|
def _write_config(self, content: str):
|
|
with open(self._config_file(), 'w') as f:
|
|
f.write(content)
|
|
self._syncconf()
|
|
|
|
def _syncconf(self):
|
|
"""Sync live WireGuard peers using 'wg set' — never touches [Interface] settings.
|
|
|
|
wg syncconf resets the ListenPort when given a peers-only config,
|
|
breaking client connections. We diff the config file against the live
|
|
interface and add/remove peers individually instead.
|
|
|
|
SAFETY: if the config file is not under the real wireguard config dir
|
|
(e.g. a test temp dir), bail out immediately — never touch the live container.
|
|
"""
|
|
import subprocess, re
|
|
real_conf = self._config_file()
|
|
if '/tmp/' in real_conf or 'pytest' in real_conf:
|
|
logger.debug('_syncconf: skipping — config path looks like a test dir')
|
|
return
|
|
try:
|
|
# Parse desired peers from config file
|
|
content = self._read_config()
|
|
desired: dict = {}
|
|
current_peer = None
|
|
for line in content.splitlines():
|
|
line = line.strip()
|
|
if line == '[Peer]':
|
|
current_peer = {}
|
|
elif current_peer is not None:
|
|
if line.startswith('PublicKey'):
|
|
current_peer['pub'] = line.split('=', 1)[1].strip()
|
|
elif line.startswith('AllowedIPs'):
|
|
current_peer['ips'] = line.split('=', 1)[1].strip()
|
|
elif line.startswith('PersistentKeepalive'):
|
|
current_peer['ka'] = line.split('=', 1)[1].strip()
|
|
elif line == '' and 'pub' in current_peer:
|
|
desired[current_peer['pub']] = current_peer
|
|
current_peer = None
|
|
if current_peer and 'pub' in current_peer:
|
|
desired[current_peer['pub']] = current_peer
|
|
|
|
# Get live peers
|
|
dump = subprocess.run(
|
|
['docker', 'exec', 'cell-wireguard', 'wg', 'show', 'wg0', 'dump'],
|
|
capture_output=True, text=True, timeout=5
|
|
)
|
|
live_pubs = set()
|
|
for line in dump.stdout.splitlines():
|
|
parts = line.split('\t')
|
|
if len(parts) >= 4 and parts[0] not in ('(none)', ''):
|
|
live_pubs.add(parts[0])
|
|
|
|
# Remove peers no longer in config
|
|
for pub in live_pubs - set(desired):
|
|
subprocess.run(
|
|
['docker', 'exec', 'cell-wireguard', 'wg', 'set', 'wg0',
|
|
'peer', pub, 'remove'],
|
|
capture_output=True, timeout=5
|
|
)
|
|
logger.info(f'wg: removed peer {pub[:16]}...')
|
|
|
|
# Add/update peers in config
|
|
for pub, p in desired.items():
|
|
args = ['docker', 'exec', 'cell-wireguard', 'wg', 'set', 'wg0',
|
|
'peer', pub,
|
|
'allowed-ips', p.get('ips', ''),
|
|
'persistent-keepalive', p.get('ka', '25')]
|
|
subprocess.run(args, capture_output=True, timeout=5)
|
|
|
|
logger.info(f'wg set applied: {len(desired)} peers')
|
|
except Exception as e:
|
|
logger.warning(f'_syncconf failed (non-fatal): {e}')
|
|
|
|
# ── Peer CRUD ─────────────────────────────────────────────────────────────
|
|
|
|
def add_peer(self, name: str, public_key: str, endpoint_ip: str,
|
|
allowed_ips: str = SERVER_NETWORK,
|
|
persistent_keepalive: int = 25) -> bool:
|
|
"""Add a [Peer] block to wg0.conf.
|
|
|
|
Server-side AllowedIPs must be the peer's specific VPN IP (/32).
|
|
Passing full-tunnel or split-tunnel CIDRs here would cause the server
|
|
to route all internet or LAN traffic to that peer — breaking everything.
|
|
"""
|
|
import ipaddress
|
|
try:
|
|
# Enforce /32: reject any CIDR wider than a single host
|
|
for cidr in (c.strip() for c in allowed_ips.split(',')):
|
|
try:
|
|
net = ipaddress.ip_network(cidr, strict=False)
|
|
if net.prefixlen < 32 and not cidr.endswith('/32'):
|
|
raise ValueError(
|
|
f"Server-side AllowedIPs must be a /32 host address, got '{cidr}'. "
|
|
"Full/split tunnel CIDRs belong in the CLIENT config only."
|
|
)
|
|
except ValueError as ve:
|
|
raise ve
|
|
|
|
content = self._read_config()
|
|
peer_block = (
|
|
f'\n[Peer]\n'
|
|
f'# {name}\n'
|
|
f'PublicKey = {public_key}\n'
|
|
f'AllowedIPs = {allowed_ips}\n'
|
|
f'PersistentKeepalive = {persistent_keepalive}\n'
|
|
)
|
|
if endpoint_ip:
|
|
peer_block += f'Endpoint = {endpoint_ip}:{DEFAULT_PORT}\n'
|
|
self._write_config(content + peer_block)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f'add_peer failed: {e}')
|
|
return False
|
|
|
|
def remove_peer(self, public_key: str) -> bool:
|
|
"""Remove the [Peer] block matching public_key from wg0.conf."""
|
|
try:
|
|
content = self._read_config()
|
|
# Split on blank lines between blocks
|
|
raw_blocks = ('\n' + content).split('\n\n')
|
|
new_blocks = [
|
|
b for b in raw_blocks
|
|
if not (f'PublicKey = {public_key}' in b and '[Peer]' in b)
|
|
]
|
|
self._write_config('\n\n'.join(new_blocks).lstrip('\n'))
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f'remove_peer failed: {e}')
|
|
return False
|
|
|
|
def get_peers(self) -> List[Dict[str, Any]]:
|
|
"""Parse wg0.conf and return list of peer dicts."""
|
|
content = self._read_config()
|
|
peers = []
|
|
sections = content.split('[Peer]')
|
|
for section in sections[1:]:
|
|
peer: Dict[str, Any] = {}
|
|
for line in section.strip().splitlines():
|
|
line = line.strip()
|
|
if not line or line.startswith('#'):
|
|
continue
|
|
if '=' not in line:
|
|
continue
|
|
key, _, value = line.partition('=')
|
|
key = key.strip().lower().replace(' ', '')
|
|
value = value.strip()
|
|
if key == 'publickey':
|
|
peer['public_key'] = value
|
|
elif key == 'allowedips':
|
|
peer['allowed_ips'] = value
|
|
elif key == 'persistentkeepalive':
|
|
try:
|
|
peer['persistent_keepalive'] = int(value)
|
|
except ValueError:
|
|
peer['persistent_keepalive'] = value
|
|
elif key == 'endpoint':
|
|
peer['endpoint'] = value
|
|
if peer:
|
|
peers.append(peer)
|
|
return peers
|
|
|
|
def update_peer_ip(self, public_key: str, new_ip: str) -> bool:
|
|
"""Update AllowedIPs for the peer with the given public key."""
|
|
content = self._read_config()
|
|
if f'PublicKey = {public_key}' not in content:
|
|
return False
|
|
lines = content.splitlines()
|
|
in_target = False
|
|
new_lines = []
|
|
for line in lines:
|
|
if line.strip() == f'PublicKey = {public_key}':
|
|
in_target = True
|
|
if in_target and line.strip().startswith('AllowedIPs'):
|
|
line = f'AllowedIPs = {new_ip}'
|
|
in_target = False
|
|
new_lines.append(line)
|
|
self._write_config('\n'.join(new_lines))
|
|
return True
|
|
|
|
SPLIT_TUNNEL_IPS = '10.0.0.0/24, 172.20.0.0/16'
|
|
FULL_TUNNEL_IPS = '0.0.0.0/0, ::/0'
|
|
|
|
def get_peer_config(self, peer_name: str, peer_ip: str,
                    peer_private_key: str,
                    server_endpoint: str = '<SERVER_IP>',
                    allowed_ips: str = None) -> str:
    """Render the client-side WireGuard config for one peer.

    Args:
        peer_name: Not used when rendering the text.
        peer_ip: The peer's VPN address; ``/32`` is appended when no prefix
            length is given.
        peer_private_key: Written to the ``[Interface]`` section as-is.
        server_endpoint: Server host, optionally with a port; when it
            contains no ``:`` the default port is appended.
        allowed_ips: Client-side AllowedIPs; defaults to ``FULL_TUNNEL_IPS``.

    Returns:
        The config text with an ``[Interface]`` and a ``[Peer]`` section.
    """
    if allowed_ips is None:
        allowed_ips = self.FULL_TUNNEL_IPS

    server_keys = self.get_keys()
    peer_dns = _resolve_peer_dns()

    if ':' in server_endpoint:
        endpoint = server_endpoint
    else:
        endpoint = f'{server_endpoint}:{DEFAULT_PORT}'

    if '/' in peer_ip:
        addr = peer_ip
    else:
        addr = f'{peer_ip}/32'

    rendered_lines = [
        '[Interface]',
        f'PrivateKey = {peer_private_key}',
        f'Address = {addr}',
        f'DNS = {peer_dns}',
        '',
        '[Peer]',
        f'PublicKey = {server_keys["public_key"]}',
        f'AllowedIPs = {allowed_ips}',
        f'Endpoint = {endpoint}',
        'PersistentKeepalive = 25',
    ]
    return '\n'.join(rendered_lines) + '\n'
|
|
|
|
# ── External IP & port ────────────────────────────────────────────────────
|
|
|
|
def _ip_cache_file(self) -> str:
|
|
return os.path.join(self.keys_dir, 'external_ip.json')
|
|
|
|
def get_external_ip(self, force_refresh: bool = False) -> Optional[str]:
|
|
"""Detect external IP, caching result for 1 hour."""
|
|
cache_file = self._ip_cache_file()
|
|
if not force_refresh and os.path.exists(cache_file):
|
|
try:
|
|
with open(cache_file) as f:
|
|
data = json.load(f)
|
|
if time.time() - data.get('ts', 0) < 3600:
|
|
return data.get('ip')
|
|
except Exception:
|
|
pass
|
|
|
|
ip = None
|
|
services = [
|
|
'https://api.ipify.org',
|
|
'https://ifconfig.me/ip',
|
|
'https://icanhazip.com',
|
|
]
|
|
if _requests:
|
|
for url in services:
|
|
try:
|
|
resp = _requests.get(url, timeout=5)
|
|
candidate = resp.text.strip()
|
|
if candidate and len(candidate) < 45:
|
|
ip = candidate
|
|
break
|
|
except Exception:
|
|
continue
|
|
|
|
if ip:
|
|
try:
|
|
with open(cache_file, 'w') as f:
|
|
json.dump({'ip': ip, 'ts': time.time()}, f)
|
|
except (PermissionError, OSError):
|
|
pass
|
|
return ip
|
|
|
|
def check_port_open(self, port: int = DEFAULT_PORT) -> bool:
    """Report whether WireGuard appears to be up and listening.

    Two signals are used, in order:

    1. ``wg show wg0`` inside the ``cell-wireguard`` container succeeds and
       its output mentions a listening port.
    2. Otherwise, any peer reported as online by ``get_all_peer_statuses``.

    Note that ``port`` is not compared with the port reported by ``wg``.
    Any error in either probe counts as "not confirmed".
    """
    interface_up = False
    try:
        probe = subprocess.run(
            ['docker', 'exec', 'cell-wireguard', 'wg', 'show', 'wg0'],
            capture_output=True, text=True, timeout=5,
        )
        interface_up = (
            probe.returncode == 0
            and 'listening port' in probe.stdout.lower()
        )
    except Exception:
        interface_up = False
    if interface_up:
        return True

    # Fallback: a recent peer handshake confirms external reachability.
    try:
        return any(
            peer_state.get('online')
            for peer_state in self.get_all_peer_statuses().values()
        )
    except Exception:
        return False
|
|
|
|
def get_server_config(self) -> Dict[str, Any]:
    """Summarise what a client needs to reach this server.

    Returns a dict with the server public key, the detected external IP,
    the ``ip:port`` endpoint (None when the IP is unknown), the default
    port, and ``port_open``, which is always None here.
    """
    server_keys = self.get_keys()
    external_ip = self.get_external_ip()

    summary: Dict[str, Any] = {
        'public_key': server_keys['public_key'],
        'external_ip': external_ip,
        'endpoint': None,
        'port': DEFAULT_PORT,
        'port_open': None,
    }
    if external_ip:
        summary['endpoint'] = f'{external_ip}:{DEFAULT_PORT}'
    return summary
|
|
|
|
def get_peer_status(self, public_key: str) -> Dict[str, Any]:
    """Look up one peer in ``wg show wg0 dump`` and return its live stats.

    A peer counts as online when its last handshake is less than 90 seconds
    old. If the peer is not listed, or the command fails, a placeholder dict
    with ``online`` set to None and zeroed transfer counters is returned.
    """
    try:
        dump = subprocess.run(
            ['docker', 'exec', 'cell-wireguard', 'wg', 'show', 'wg0', 'dump'],
            capture_output=True, text=True, timeout=5,
        )
        for row in dump.stdout.splitlines():
            cols = row.split('\t')
            # Peer rows: pubkey psk endpoint allowed_ips handshake rx tx keepalive
            if len(cols) < 8 or cols[0] != public_key:
                continue
            seen_at = int(cols[4]) if cols[4].isdigit() else 0
            current = int(time.time())
            elapsed = (current - seen_at) if seen_at else None
            if seen_at:
                seen_iso = datetime.utcfromtimestamp(seen_at).isoformat()
            else:
                seen_iso = None
            return {
                'online': elapsed is not None and elapsed < 90,
                'last_handshake': seen_iso,
                'last_handshake_seconds_ago': elapsed,
                'endpoint': None if cols[2] == '(none)' else cols[2],
                'transfer_rx': int(cols[5]) if cols[5].isdigit() else 0,
                'transfer_tx': int(cols[6]) if cols[6].isdigit() else 0,
            }
    except Exception as e:
        logger.debug(f'get_peer_status failed: {e}')
    return {'online': None, 'last_handshake': None, 'transfer_rx': 0, 'transfer_tx': 0}
|
|
|
|
def get_all_peer_statuses(self) -> Dict[str, Any]:
    """Return ``{public_key: status_dict}`` for every peer row of ``wg show wg0 dump``.

    Rows with fewer than 8 tab-separated fields are skipped. If the command
    fails, whatever was collected so far (possibly nothing) is returned.
    """
    collected: Dict[str, Any] = {}
    try:
        dump = subprocess.run(
            ['docker', 'exec', 'cell-wireguard', 'wg', 'show', 'wg0', 'dump'],
            capture_output=True, text=True, timeout=5,
        )
        # One timestamp for the whole snapshot so ages are comparable.
        current = int(time.time())
        for row in dump.stdout.splitlines():
            cols = row.split('\t')
            if len(cols) < 8:
                continue
            seen_at = int(cols[4]) if cols[4].isdigit() else 0
            elapsed = (current - seen_at) if seen_at else None
            if seen_at:
                seen_iso = datetime.utcfromtimestamp(seen_at).isoformat()
            else:
                seen_iso = None
            collected[cols[0]] = {
                'online': elapsed is not None and elapsed < 90,
                'last_handshake': seen_iso,
                'last_handshake_seconds_ago': elapsed,
                'endpoint': None if cols[2] == '(none)' else cols[2],
                'transfer_rx': int(cols[5]) if cols[5].isdigit() else 0,
                'transfer_tx': int(cols[6]) if cols[6].isdigit() else 0,
            }
    except Exception as e:
        logger.debug(f'get_all_peer_statuses failed: {e}')
    return collected
|
|
|
|
# ── Status & connectivity ─────────────────────────────────────────────────
|
|
|
|
def get_status(self) -> Dict[str, Any]:
    """Report whether the ``cell-wireguard`` container shows up in ``docker ps``.

    Returns a dict with ``running``, ``status`` ('online'/'offline'),
    ``interface``, ``ip_info`` (server address when running, else empty),
    ``peers_count`` (peers found in the config) and a UTC ``timestamp``.
    Any exception is passed to ``handle_error`` and its result is returned.
    """
    try:
        listing = subprocess.run(
            ['docker', 'ps', '--filter', 'name=cell-wireguard', '--format', '{{.Names}}'],
            capture_output=True, text=True, timeout=5,
        )
        is_up = 'cell-wireguard' in listing.stdout

        report: Dict[str, Any] = {'running': is_up}
        report['status'] = 'online' if is_up else 'offline'
        report['interface'] = 'wg0'
        report['ip_info'] = {'address': SERVER_ADDRESS} if is_up else {}
        report['peers_count'] = len(self.get_peers())
        report['timestamp'] = datetime.utcnow().isoformat()
        return report
    except Exception as e:
        return self.handle_error(e, 'get_status')
|
|
|
|
def test_connectivity(self, peer_ip: str) -> Dict[str, Any]:
    """Send a single ping to ``peer_ip`` and report the outcome.

    Returns a dict with ``peer_ip``, ``ping_success``, ``ping_output`` and
    ``ping_error``. If the ping command cannot be run (or times out), the
    exception text is placed in ``ping_error`` and ``ping_success`` is False.
    """
    report: Dict[str, Any] = {
        'peer_ip': peer_ip,
        'ping_success': False,
        'ping_output': '',
        'ping_error': '',
    }
    try:
        ping = subprocess.run(
            ['ping', '-c', '1', '-W', '2', peer_ip],
            capture_output=True, text=True, timeout=5,
        )
    except Exception as e:
        report['ping_error'] = str(e)
        return report

    report['ping_success'] = ping.returncode == 0
    report['ping_output'] = ping.stdout
    report['ping_error'] = ping.stderr
    return report
|
|
|
|
def get_metrics(self) -> Dict[str, Any]:
    """Return a compact metrics dict derived from ``get_status``.

    ``status`` falls back to 'unknown' and ``peers_count`` to 0 when the
    status dict lacks those keys.
    """
    current = self.get_status()
    metrics: Dict[str, Any] = {'service': 'wireguard'}
    metrics['timestamp'] = datetime.utcnow().isoformat()
    metrics['status'] = current.get('status', 'unknown')
    metrics['peers_count'] = current.get('peers_count', 0)
    return metrics
|
|
|
|
def restart_service(self) -> bool:
    """Restart the ``cell-wireguard`` container via ``docker restart``.

    Returns True when the command exits with status 0; False when it exits
    non-zero or cannot be run at all (the error is logged).
    """
    command = ['docker', 'restart', 'cell-wireguard']
    try:
        outcome = subprocess.run(
            command, capture_output=True, text=True, timeout=30,
        )
    except Exception as e:
        logger.error(f'restart_service failed: {e}')
        return False
    return outcome.returncode == 0
|