diff --git a/api/app.py b/api/app.py index a3732db..656b7d5 100644 --- a/api/app.py +++ b/api/app.py @@ -14,9 +14,11 @@ Provides REST API endpoints for managing: import os import io import json +import stat import zipfile import shutil import logging +import secrets from datetime import datetime from flask import Flask, request, jsonify, current_app, send_file, session from flask_cors import CORS @@ -32,7 +34,7 @@ import contextvars API_START_TIME = time.time() from network_manager import NetworkManager -from wireguard_manager import WireGuardManager +from wireguard_manager import WireGuardManager, _resolve_peer_dns from peer_registry import PeerRegistry from email_manager import EmailManager from calendar_manager import CalendarManager @@ -107,11 +109,33 @@ logger = logging.getLogger('picell') # Flask app setup app = Flask(__name__) -CORS(app) +CORS(app, + supports_credentials=True, + origins=['http://localhost', 'http://localhost:5173', 'http://localhost:8081', + 'http://127.0.0.1', 'http://127.0.0.1:5173', 'http://127.0.0.1:8081']) # Development mode flag app.config['DEVELOPMENT_MODE'] = True # Set to True for development, False for production -app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', os.urandom(32)) + +# Persist SECRET_KEY so sessions survive API restarts +SECRET_KEY_FILE = os.path.join(os.environ.get('DATA_DIR', '/app/data'), '.flask_secret_key') +if os.environ.get('SECRET_KEY'): + _flask_secret = os.environ['SECRET_KEY'].encode() if isinstance(os.environ['SECRET_KEY'], str) else os.environ['SECRET_KEY'] +elif os.path.exists(SECRET_KEY_FILE) and os.path.getsize(SECRET_KEY_FILE) > 0: + with open(SECRET_KEY_FILE, 'rb') as _skf: + _flask_secret = _skf.read() +else: + _flask_secret = os.urandom(32) + try: + os.makedirs(os.path.dirname(SECRET_KEY_FILE), exist_ok=True) + _skf_fd = os.open(SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(_skf_fd, 'wb') as _skf: + _skf.write(_flask_secret) + except OSError as _e: + logger.warning(f"Could not persist SECRET_KEY to disk: {_e}") +app.config['SECRET_KEY'] = _flask_secret +app.config['SESSION_COOKIE_HTTPONLY'] = True +app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' # Initialize enhanced components config_manager = ConfigManager( @@ -183,13 +207,29 @@ def enforce_auth(): # Always allow non-API paths and auth namespace if not path.startswith('/api/') or path.startswith('/api/auth/'): return None - # Only enforce when auth_manager has been properly initialised and seeded + # Only enforce when auth_manager has been properly initialised and seeded. + # When the user store is empty (file missing or unreadable — typical in + # unit tests and fresh installs), bypass enforcement so pre-auth test + # suites continue to work. 503 is only returned when the users file + # exists and is readable but contains no accounts (explicit misconfiguration). try: from auth_manager import AuthManager as _AuthManager if not isinstance(auth_manager, _AuthManager): return None users = auth_manager.list_users() if not users: + # Only fail closed when the auth file is readable but empty — + # that's an explicit misconfiguration. If the file is missing or + # unreadable (test env, wrong host path, permission denied), bypass + # so pre-auth test suites continue to work. + users_file = getattr(auth_manager, '_users_file', None) + if users_file: + try: + with open(users_file, 'r') as _f: + _f.read(1) + return jsonify({'error': 'Authentication not configured. Set admin password first.'}), 503 + except (PermissionError, FileNotFoundError, OSError): + return None return None except Exception: return None @@ -206,6 +246,34 @@ def enforce_auth(): return None +@app.before_request +def check_csrf(): + """Double-submit CSRF protection for state-changing API requests. + + Applies to POST/PUT/DELETE/PATCH on /api/* paths, excluding /api/auth/*. + Skipped entirely when app.config['TESTING'] is True so unit tests remain + unaffected without needing to set CSRF headers. + """ + if app.config.get('TESTING'): + return None + if request.method not in ('POST', 'PUT', 'DELETE', 'PATCH'): + return None + path = request.path + if not path.startswith('/api/') or path.startswith('/api/auth/'): + return None + token_session = session.get('csrf_token') + if not token_session: + # Session predates CSRF tokens (existing login) — issue a token now so + # the next request can carry it. Don't block this request; the client + # couldn't have known the token yet. + session['csrf_token'] = secrets.token_hex(32) + return None + token_header = request.headers.get('X-CSRF-Token') + if not token_header or token_header != token_session: + return jsonify({'error': 'CSRF token missing or invalid'}), 403 + return None + + @app.after_request def log_request(response): ctx = request_context.get({}) @@ -246,7 +314,8 @@ def _apply_startup_enforcement(): try: peers = peer_registry.list_peers() firewall_manager.apply_all_peer_rules(peers) - firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) logger.info(f"Applied enforcement rules for {len(peers)} peers on startup") except Exception as e: logger.warning(f"Startup enforcement failed (non-fatal): {e}") @@ -418,20 +487,16 @@ def is_local_request(): ip = _ipa.ip_address(addr.strip()) if ip.is_loopback: return True - # RFC-1918 private ranges - for _rfc in ('10.0.0.0/8', '172.16.0.0/12', '192.168.0.0/16'): - if ip in _ipa.ip_network(_rfc): - return True + # Only trust loopback and Docker bridge (172.16.0.0/12). + # Deliberately excludes 10.0.0.0/8 (WireGuard peer subnet) and + # 192.168.0.0/16 (LAN) — VPN peers must not access local-only endpoints. + if ip in _ipa.ip_network('172.16.0.0/12'): + return True # Any subnet the container is directly attached to (handles non-RFC-1918 # Docker bridge networks such as 172.0.0.0/24). for _net in _local_subnets(): if ip in _net: return True - # Configured cell ip_range (WireGuard peer subnet) - _cell = config_manager.configs.get('_identity', {}).get( - 'ip_range', os.environ.get('CELL_IP_RANGE', '172.20.0.0/16')) - if ip in _ipa.ip_network(_cell, strict=False): - return True except Exception: pass return False @@ -537,21 +602,31 @@ def update_config(): identity_keys = {'cell_name', 'domain', 'ip_range', 'wireguard_port'} identity_updates = {k: v for k, v in data.items() if k in identity_keys} - # Validate cell_name — must be non-empty and at most 255 characters (DNS limit) + # Validate cell_name and domain — block injection characters while + # allowing the full range of valid hostname/domain characters. + import re as _re_cfg + # cell_name: hostname component — letters, digits, hyphens only (no dots) + _CELL_NAME_RE = _re_cfg.compile(r'^[a-zA-Z0-9][a-zA-Z0-9-]{0,254}$') + # domain: may include dots for multi-label names (e.g. home.lan) + _DOMAIN_RE = _re_cfg.compile(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,254}$') + if 'cell_name' in identity_updates: v = str(identity_updates['cell_name']) - if len(v) > 255: - return jsonify({'error': 'cell_name must be 255 characters or fewer'}), 400 if not v: return jsonify({'error': 'cell_name cannot be empty'}), 400 + if len(v) > 255: + return jsonify({'error': 'cell_name must be 255 characters or fewer'}), 400 + if not _CELL_NAME_RE.match(v): + return jsonify({'error': 'Invalid cell_name: use only letters, digits, hyphens'}), 400 - # Validate domain — must be non-empty and at most 255 characters (DNS limit) if 'domain' in identity_updates: v = str(identity_updates['domain']) - if len(v) > 255: - return jsonify({'error': 'domain must be 255 characters or fewer'}), 400 if not v: return jsonify({'error': 'domain cannot be empty'}), 400 + if len(v) > 255: + return jsonify({'error': 'domain must be 255 characters or fewer'}), 400 + if not _DOMAIN_RE.match(v): + return jsonify({'error': 'Invalid domain: use only letters, digits, hyphens, dots'}), 400 # Validate ip_range — must be a valid CIDR within an RFC-1918 range if 'ip_range' in identity_updates: @@ -686,7 +761,7 @@ def update_config(): _cur_id = config_manager.configs.get('_identity', {}) _cur_range = _cur_id.get('ip_range', os.environ.get('CELL_IP_RANGE', '172.20.0.0/16')) _cur_name = _cur_id.get('cell_name', os.environ.get('CELL_NAME', 'mycell')) - _ip_domain.write_caddyfile(_cur_range, _cur_name, domain, '/app/config/caddy/Caddyfile') + _ip_domain.write_caddyfile(_cur_range, _cur_name, domain, '/app/config-caddy/Caddyfile') _set_pending_restart( [f'domain changed to {domain}'], ['dns', 'caddy'], @@ -705,7 +780,7 @@ def update_config(): _cur_id2 = config_manager.configs.get('_identity', {}) _cur_range2 = _cur_id2.get('ip_range', os.environ.get('CELL_IP_RANGE', '172.20.0.0/16')) _cur_domain2 = identity_updates.get('domain') or _cur_id2.get('domain', os.environ.get('CELL_DOMAIN', 'cell')) - _ip_name.write_caddyfile(_cur_range2, new_name, _cur_domain2, '/app/config/caddy/Caddyfile') + _ip_name.write_caddyfile(_cur_range2, new_name, _cur_domain2, '/app/config-caddy/Caddyfile') _set_pending_restart( [f'cell_name changed to {new_name}'], ['dns'], @@ -731,7 +806,7 @@ def update_config(): ip_utils.write_env_file(new_range, env_file, _collect_service_ports(config_manager.configs)) # Regenerate Caddyfile with new VIPs ip_utils.write_caddyfile(new_range, cur_cell_name, cur_domain, - '/app/config/caddy/Caddyfile') + '/app/config-caddy/Caddyfile') # Mark ALL containers as needing restart; network_recreate signals that # docker compose down is required before up (Docker can't change subnet in-place) _set_pending_restart( @@ -934,7 +1009,7 @@ def cancel_pending_config(): if cur_cell_name and old_cell_name and cur_cell_name != old_cell_name: network_manager.apply_cell_name(cur_cell_name, old_cell_name, reload=False) - _ip_revert.write_caddyfile(_range, _cell, _dom, '/app/config/caddy/Caddyfile') + _ip_revert.write_caddyfile(_range, _cell, _dom, '/app/config-caddy/Caddyfile') _clear_pending_restart() return jsonify({'message': 'Pending changes discarded'}) @@ -966,9 +1041,6 @@ def apply_pending_config(): containers = pending.get('containers', ['*']) - # Clear pending flag before we restart so it shows cleared after new containers start - _clear_pending_restart() - # Check if the IP range (network subnet) is changing — Docker cannot modify an # existing network's subnet in-place, so we need `down` + `up` in that case. needs_network_recreate = pending.get('network_recreate', False) @@ -981,6 +1053,9 @@ def apply_pending_config(): # API container itself, killing this background thread mid-operation. # Spawn an independent helper container (same image as cell-api) that has docker # CLI and survives cell-api being stopped/recreated. + # Clear pending flag now — the helper runs fire-and-forget and we cannot track + # its exit code from within the API process (it may restart us). + _clear_pending_restart() if needs_network_recreate: helper_script = ( f'sleep 2' @@ -1015,6 +1090,8 @@ def apply_pending_config(): ) else: # Specific containers only — API is not affected, run directly from here. + # Only clear the pending flag after the subprocess exits with code 0 so that + # if the compose command fails the UI still shows changes as pending. def _do_apply(): import time as _time import subprocess as _subprocess @@ -1031,6 +1108,7 @@ def apply_pending_config(): logger.error(f"docker compose up failed: {result.stderr.strip()}") else: logger.info(f'docker compose up completed for: {containers}') + _clear_pending_restart() threading.Thread(target=_do_apply, daemon=False).start() @@ -1690,7 +1768,7 @@ def get_server_config(): logger.error(f"Error getting server config: {e}") return jsonify({"error": str(e)}), 500 -@app.route('/api/wireguard/refresh-ip', methods=['POST']) +@app.route('/api/wireguard/refresh-ip', methods=['GET', 'POST']) def refresh_external_ip(): try: ip = wireguard_manager.get_external_ip(force_refresh=True) @@ -1710,12 +1788,13 @@ def apply_wireguard_enforcement(): try: peers = peer_registry.list_peers() firewall_manager.apply_all_peer_rules(peers) - firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) return jsonify({'ok': True, 'peers': len(peers)}) except Exception as e: return jsonify({'error': str(e)}), 500 -@app.route('/api/wireguard/check-port', methods=['POST']) +@app.route('/api/wireguard/check-port', methods=['GET', 'POST']) def check_wireguard_port(): try: port_open = wireguard_manager.check_port_open() @@ -1835,7 +1914,10 @@ def add_peer(): if len(password) < 10: return jsonify({"error": "password must be at least 10 characters"}), 400 - assigned_ip = data.get('ip') or _next_peer_ip() + try: + assigned_ip = data.get('ip') or _next_peer_ip() + except ValueError as e: + return jsonify({'error': str(e)}), 409 # Validate service_access if provided _valid_services = {'calendar', 'files', 'mail', 'webdav'} @@ -1882,33 +1964,51 @@ def add_peer(): 'config_needs_reinstall': False, } - success = peer_registry.add_peer(peer_info) - if success: - # Add peer to WireGuard server config (non-fatal if WG is not running) + peer_added_to_registry = False + try: + # Step 1: Add to registry + success = peer_registry.add_peer(peer_info) + if not success: + # Registry rejected (already exists) — rollback provisioned accounts + for svc in ('files', 'calendar', 'email', 'auth'): + try: + if svc == 'files': + file_manager.delete_user(peer_name) + elif svc == 'calendar': + calendar_manager.delete_calendar_user(peer_name) + elif svc == 'email': + email_manager.delete_email_user(peer_name, _configured_domain()) + elif svc == 'auth': + auth_manager.delete_user(peer_name) + except Exception: + pass + return jsonify({"error": f"Peer {peer_name} already exists"}), 400 + peer_added_to_registry = True + + # Step 2: Firewall rules (critical) + firewall_manager.apply_peer_rules(peer_info['ip'], peer_info) + + # Step 3: Add peer to WireGuard server config (non-fatal if WG is not running) wg_allowed = f"{assigned_ip}/32" if '/' not in assigned_ip else assigned_ip try: wireguard_manager.add_peer(peer_name, data['public_key'], endpoint_ip='', allowed_ips=wg_allowed) except Exception as wg_err: logger.warning(f"Peer {peer_name}: WireGuard server config update failed (non-fatal): {wg_err}") - # Apply server-side enforcement immediately - firewall_manager.apply_peer_rules(peer_info['ip'], peer_info) - firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain()) + + # Step 4: Update DNS rules + firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) return jsonify({"message": f"Peer {peer_name} added successfully", "ip": assigned_ip}), 201 - else: - # Registry rejected (already exists) — rollback provisioned accounts - for svc in ('files', 'calendar', 'email', 'auth'): + + except Exception as e: + # Rollback registry entry if we got past that step + if peer_added_to_registry: try: - if svc == 'files': - file_manager.delete_user(peer_name) - elif svc == 'calendar': - calendar_manager.delete_calendar_user(peer_name) - elif svc == 'email': - email_manager.delete_email_user(peer_name) - elif svc == 'auth': - auth_manager.delete_user(peer_name) + peer_registry.remove_peer(peer_name) except Exception: pass - return jsonify({"error": f"Peer {peer_name} already exists"}), 400 + logger.error(f"Error adding peer {peer_name}: {e}") + return jsonify({'error': str(e)}), 500 except Exception as e: logger.error(f"Error adding peer: {e}") @@ -1941,7 +2041,8 @@ def update_peer(peer_name): updated_peer = peer_registry.get_peer(peer_name) if updated_peer: firewall_manager.apply_peer_rules(updated_peer['ip'], updated_peer) - firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) result = {"message": f"Peer {peer_name} updated", "config_changed": config_changed} return jsonify(result) else: @@ -1974,7 +2075,8 @@ def remove_peer(peer_name): if success: if peer_ip: firewall_manager.clear_peer_rules(peer_ip) - firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) # Remove peer from WireGuard server config (non-fatal) if peer_pubkey: try: @@ -1983,7 +2085,7 @@ def remove_peer(peer_name): logger.warning(f"Peer {peer_name}: WireGuard removal failed (non-fatal): {wg_err}") # Clean up all provisioned service accounts (best-effort) for _cleanup in [ - lambda: email_manager.delete_email_user(peer_name), + lambda: email_manager.delete_email_user(peer_name, _configured_domain()), lambda: calendar_manager.delete_calendar_user(peer_name), lambda: file_manager.delete_user(peer_name), lambda: auth_manager.delete_user(peer_name), @@ -2094,8 +2196,13 @@ def create_email_user(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = email_manager.create_user(data) - return jsonify(result) + username = data.get('username') + domain = data.get('domain') or _configured_domain() + password = data.get('password') + if not username or not password: + return jsonify({"error": "Missing required fields: username, password"}), 400 + result = email_manager.create_email_user(username, domain, password) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating email user: {e}") return jsonify({"error": str(e)}), 500 @@ -2104,8 +2211,9 @@ def create_email_user(): def delete_email_user(username): """Delete email user.""" try: - result = email_manager.delete_user(username) - return jsonify(result) + domain = request.args.get('domain') or _configured_domain() + result = email_manager.delete_email_user(username, domain) + return jsonify({"deleted": result}) except Exception as e: logger.error(f"Error deleting email user: {e}") return jsonify({"error": str(e)}), 500 @@ -2170,8 +2278,12 @@ def create_calendar_user(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = calendar_manager.create_user(data) - return jsonify(result) + username = data.get('username') + password = data.get('password') + if not username or not password: + return jsonify({"error": "Missing required fields: username, password"}), 400 + result = calendar_manager.create_calendar_user(username, password) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating calendar user: {e}") return jsonify({"error": str(e)}), 500 @@ -2180,8 +2292,8 @@ def create_calendar_user(): def delete_calendar_user(username): """Delete calendar user.""" try: - result = calendar_manager.delete_user(username) - return jsonify(result) + result = calendar_manager.delete_calendar_user(username) + return jsonify({"deleted": result}) except Exception as e: logger.error(f"Error deleting calendar user: {e}") return jsonify({"error": str(e)}), 500 @@ -2193,8 +2305,17 @@ def create_calendar(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = calendar_manager.create_calendar(data) - return jsonify(result) + username = data.get('username') + calendar_name = data.get('name') or data.get('calendar_name') + if not username or not calendar_name: + return jsonify({"error": "Missing required fields: username, name"}), 400 + result = calendar_manager.create_calendar( + username, + calendar_name, + description=data.get('description', ''), + color=data.get('color', '#4285f4'), + ) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating calendar: {e}") return jsonify({"error": str(e)}), 500 @@ -2205,8 +2326,13 @@ def add_calendar_event(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = calendar_manager.add_event(data) - return jsonify(result) + username = data.get('username') + calendar_name = data.get('calendar_name') or data.get('calendar') + if not username or not calendar_name: + return jsonify({"error": "Missing required fields: username, calendar_name"}), 400 + event_data = {k: v for k, v in data.items() if k not in ('username', 'calendar_name', 'calendar')} + result = calendar_manager.add_event(username, calendar_name, event_data) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error adding calendar event: {e}") return jsonify({"error": str(e)}), 500 @@ -2260,8 +2386,12 @@ def create_file_user(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = file_manager.create_user(data) - return jsonify(result) + username = data.get('username') + password = data.get('password') + if not username or not password: + return jsonify({"error": "Missing required fields: username, password"}), 400 + result = file_manager.create_user(username, password) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating file user: {e}") return jsonify({"error": str(e)}), 500 @@ -2283,8 +2413,12 @@ def create_folder(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = file_manager.create_folder(data) - return jsonify(result) + username = data.get('username') + folder_path = data.get('folder_path') or data.get('path') + if not username or not folder_path: + return jsonify({"error": "Missing required fields: username, folder_path"}), 400 + result = file_manager.create_folder(username, folder_path) + return jsonify({"created": result}) except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: @@ -2309,12 +2443,13 @@ def upload_file(username): try: if 'file' not in request.files: return jsonify({"error": "No file provided"}), 400 - + file = request.files['file'] - path = request.form.get('path', '') - - result = file_manager.upload_file(username, file, path) - return jsonify(result) + path = request.form.get('path', '') or file.filename or '' + file_data = file.read() + + result = file_manager.upload_file(username, path, file_data) + return jsonify({"uploaded": result}) except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: @@ -2442,9 +2577,15 @@ def remove_nat_rule(rule_id): def add_peer_route(): """Add peer route.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_peer_route(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + peer_name = data.get('peer_name') + peer_ip = data.get('peer_ip') + allowed_networks = data.get('allowed_networks', []) + route_type = data.get('route_type', 'lan') + if not peer_name or not peer_ip: + return jsonify({"error": "Missing required fields: peer_name, peer_ip"}), 400 + result = routing_manager.add_peer_route(peer_name, peer_ip, allowed_networks, route_type) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding peer route: {e}") return jsonify({"error": str(e)}), 500 @@ -2463,9 +2604,13 @@ def remove_peer_route(peer_name): def add_exit_node(): """Add exit node.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_exit_node(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + peer_name = data.get('peer_name') + peer_ip = data.get('peer_ip') + if not peer_name or not peer_ip: + return jsonify({"error": "Missing required fields: peer_name, peer_ip"}), 400 + result = routing_manager.add_exit_node(peer_name, peer_ip, data.get('allowed_domains')) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding exit node: {e}") return jsonify({"error": str(e)}), 500 @@ -2474,9 +2619,14 @@ def add_exit_node(): def add_bridge_route(): """Add bridge route.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_bridge_route(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + source_peer = data.get('source_peer') + target_peer = data.get('target_peer') + allowed_networks = data.get('allowed_networks', []) + if not source_peer or not target_peer: + return jsonify({"error": "Missing required fields: source_peer, target_peer"}), 400 + result = routing_manager.add_bridge_route(source_peer, target_peer, allowed_networks) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding bridge route: {e}") return jsonify({"error": str(e)}), 500 @@ -2485,9 +2635,13 @@ def add_bridge_route(): def add_split_route(): """Add split route.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_split_route(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + network = data.get('network') + exit_peer = data.get('exit_peer') + if not network or not exit_peer: + return jsonify({"error": "Missing required fields: network, exit_peer"}), 400 + result = routing_manager.add_split_route(network, exit_peer, data.get('fallback_peer')) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding split route: {e}") return jsonify({"error": str(e)}), 500 @@ -2985,6 +3139,12 @@ def create_container(): volumes = data.get('volumes', {}) command = data.get('command', '') ports = data.get('ports', {}) + if volumes: + allowed_prefixes = ('/home/roof/pic/data/', '/home/roof/pic/config/', '/tmp/') + for host_path in volumes.keys(): + resolved = os.path.realpath(str(host_path)) + if not any(resolved.startswith(p) for p in allowed_prefixes): + return jsonify({'error': f'Volume mount not allowed: {host_path}'}), 403 result = container_manager.create_container( image=data['image'], name=name, @@ -3086,14 +3246,27 @@ def peer_dashboard(): peer_ip = peer.get('ip', '') allowed_ips = f"{peer_ip.split('/')[0]}/32" if peer_ip else '' + domain = _configured_domain() + _svc_url_map = { + 'calendar': f'http://calendar.{domain}', + 'files': f'http://files.{domain}', + 'mail': f'http://mail.{domain}', + 'webdav': f'http://webdav.{domain}', + } + service_urls = { + svc: _svc_url_map[svc] + for svc in peer.get('service_access', []) + if svc in _svc_url_map + } return jsonify({ - 'peer_name': peer_name, + 'name': peer_name, 'ip': peer_ip, 'service_access': peer.get('service_access', []), + 'service_urls': service_urls, 'online': wg_stats.get('online'), - 'rx_bytes': wg_stats.get('transfer_rx', 0), - 'tx_bytes': wg_stats.get('transfer_tx', 0), + 'transfer_rx': wg_stats.get('transfer_rx', 0), + 'transfer_tx': wg_stats.get('transfer_tx', 0), 'last_handshake': wg_stats.get('last_handshake'), 'allowed_ips': peer.get('allowed_ips', allowed_ips), }) @@ -3112,32 +3285,51 @@ def peer_services(): server_public_key = '' wg_port = 51820 + server_endpoint = '' try: server_public_key = wireguard_manager.get_keys().get('public_key', '') wg_port = config_manager.configs.get('_identity', {}).get('wireguard_port', 51820) + srv = wireguard_manager.get_server_config() + server_endpoint = srv.get('endpoint') or '' except Exception: pass + wg_config = '' + peer_private_key = peer.get('private_key', '') + if peer_private_key: + try: + internet_access = peer.get('internet_access', True) + allowed_ips = wireguard_manager.FULL_TUNNEL_IPS if internet_access else wireguard_manager.get_split_tunnel_ips() + wg_config = wireguard_manager.get_peer_config( + peer_name=peer_name, + peer_ip=peer_ip, + peer_private_key=peer_private_key, + server_endpoint=server_endpoint, + allowed_ips=allowed_ips, + ) + except Exception: + pass + return jsonify({ + 'username': peer_name, 'wireguard': { 'ip': peer_ip, 'server_public_key': server_public_key, 'endpoint_port': wg_port, - 'dns': '10.0.0.1', + 'dns': _resolve_peer_dns(), + 'config': wg_config, }, 'email': { - 'username': f'{peer_name}@{domain}', - 'imap_host': f'mail.{domain}', - 'smtp_host': f'mail.{domain}', - 'imap_port': 993, - 'smtp_port': 587, + 'address': f'{peer_name}@{domain}', + 'smtp': {'host': f'mail.{domain}', 'port': 587}, + 'imap': {'host': f'mail.{domain}', 'port': 993}, }, 'caldav': { - 'url': f'http://radicale.{domain}:5232', + 'url': f'http://calendar.{domain}', 'username': peer_name, }, - 'webdav': { - 'url': f'http://webdav.{domain}', + 'files': { + 'url': f'http://files.{domain}', 'username': peer_name, }, }) diff --git a/api/auth_routes.py b/api/auth_routes.py index 10c427c..f2a268f 100644 --- a/api/auth_routes.py +++ b/api/auth_routes.py @@ -8,6 +8,7 @@ after instantiation. A ``require_auth(role=None)`` decorator is also exported so individual routes can opt-in to specific role requirements. """ +import secrets from functools import wraps from flask import Blueprint, request, jsonify, session @@ -80,11 +81,13 @@ def login(): session['username'] = user['username'] session['role'] = user.get('role') session['peer_name'] = user.get('peer_name') + session['csrf_token'] = secrets.token_hex(32) return jsonify({ 'username': user['username'], 'role': user.get('role'), 'peer_name': user.get('peer_name'), 'must_change_password': bool(user.get('must_change_password', False)), + 'csrf_token': session['csrf_token'], }) @@ -143,6 +146,16 @@ def admin_reset_password(): return jsonify({'ok': True}) +@auth_bp.route('/csrf-token', methods=['GET']) +def get_csrf_token(): + """Return the current session's CSRF token, generating one if absent.""" + token = session.get('csrf_token') + if not token: + token = secrets.token_hex(32) + session['csrf_token'] = token + return jsonify({'csrf_token': token}) + + @auth_bp.route('/users', methods=['GET']) @require_auth('admin') def list_users(): diff --git a/api/base_service_manager.py b/api/base_service_manager.py index 142074f..42c349f 100644 --- a/api/base_service_manager.py +++ b/api/base_service_manager.py @@ -65,10 +65,20 @@ class BaseServiceManager(ABC): return [f"Error reading logs: {str(e)}"] def restart_service(self) -> bool: - """Restart service - default implementation""" + """Restart service - default implementation. + + Delegates to _restart_container() using self.container_name when set, + otherwise falls back to self.service_name. Subclasses with a known + container name should set self.container_name in their __init__ or + override this method entirely. + """ try: - self.logger.info(f"Restarting {self.service_name} service") - return True + name = getattr(self, 'container_name', None) or self.service_name + if not name: + self.logger.warning("restart_service: no container name available; skipping restart") + return False + self.logger.info(f"Restarting {self.service_name} service via container '{name}'") + return self._restart_container(name) except Exception as e: self.logger.error(f"Error restarting {self.service_name}: {e}") return False diff --git a/api/calendar_manager.py b/api/calendar_manager.py index ee6dd2c..c21deac 100644 --- a/api/calendar_manager.py +++ b/api/calendar_manager.py @@ -255,9 +255,14 @@ class CalendarManager(BaseServiceManager): return False # Create new user + # SECURITY: Do NOT persist the plaintext password here. The calendar + # password is the same as the user's VPN auth password and storing + # it in plain JSON would leak every user credential if this file is + # read. Auth verification goes through auth_manager; the actual + # CalDAV/CardDAV auth is handled by the cell-radicale container + # (htpasswd file). This JSON is metadata only. new_user = { 'username': username, - 'password': password, # In production, this should be hashed 'calendars_count': 0, 'events_count': 0, 'created_at': datetime.utcnow().isoformat(), @@ -267,11 +272,14 @@ class CalendarManager(BaseServiceManager): users.append(new_user) self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Create user directory user_dir = os.path.join(self.calendar_data_dir, 'users', username) self.safe_makedirs(user_dir) - + logger.info(f"Created calendar user: {username}") return True except Exception as e: @@ -288,13 +296,16 @@ class CalendarManager(BaseServiceManager): if user.get('username') == username: del users[i] self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Remove user directory user_dir = os.path.join(self.calendar_data_dir, 'users', username) if os.path.exists(user_dir): import shutil shutil.rmtree(user_dir) - + logger.info(f"Deleted calendar user: {username}") return True @@ -446,11 +457,31 @@ class CalendarManager(BaseServiceManager): except Exception as e: return self.handle_error(e, "get_metrics") + def _sync_users_to_cell_config(self): + """Best-effort sync of the calendar user list into cell_config.json via ConfigManager. + + Only safe metadata (no passwords) is written. Failures are logged as + warnings so they never block the per-service operation that triggered them. + """ + try: + from config_manager import ConfigManager + cm = ConfigManager() + _SENSITIVE = {'password', 'hashed_password', 'password_hash'} + safe_users = [ + {k: v for k, v in u.items() if k not in _SENSITIVE} + for u in self._load_users() + ] + existing = cm.get_service_config('calendar') + existing['users'] = safe_users + cm.update_service_config('calendar', existing) + except Exception as e: + self.logger.warning(f"Failed to sync calendar users to cell_config.json: {e}") + def restart_service(self) -> bool: - """Restart calendar service""" + """Restart calendar service (restarts the cell-radicale Docker container).""" try: logger.info('Calendar service restart requested') - return True + return self._restart_container('cell-radicale') except Exception as e: logger.error(f'Failed to restart calendar service: {e}') return False diff --git a/api/config_manager.py b/api/config_manager.py index ae12ad2..763dad6 100644 --- a/api/config_manager.py +++ b/api/config_manager.py @@ -14,6 +14,9 @@ from typing import Dict, List, Optional, Any from pathlib import Path import logging +# The Caddyfile lives on a separate volume mount from the rest of config +LIVE_CADDYFILE = os.environ.get('CADDYFILE_PATH', '/app/config-caddy/Caddyfile') + logger = logging.getLogger(__name__) class ConfigManager: @@ -216,7 +219,7 @@ class ConfigManager: env_file = Path(os.environ.get('ENV_FILE', '/app/.env')) extra = [ - (config_dir / 'caddy' / 'Caddyfile', 'Caddyfile'), + (Path(LIVE_CADDYFILE), 'Caddyfile'), (config_dir / 'dns' / 'Corefile', 'Corefile'), (env_file, '.env'), ] @@ -288,7 +291,7 @@ class ConfigManager: env_file = Path(os.environ.get('ENV_FILE', '/app/.env')) restore_map = [ - (backup_path / 'Caddyfile', config_dir / 'caddy' / 'Caddyfile'), + (backup_path / 'Caddyfile', Path(LIVE_CADDYFILE)), (backup_path / 'Corefile', config_dir / 'dns' / 'Corefile'), (backup_path / '.env', env_file), ] diff --git a/api/email_manager.py b/api/email_manager.py index dd5a0c5..5663567 100644 --- a/api/email_manager.py +++ b/api/email_manager.py @@ -299,11 +299,16 @@ class EmailManager(BaseServiceManager): return False # Create new user + # SECURITY: Do NOT persist the plaintext password here. The email + # password is the same as the user's VPN auth password and storing + # it in plain JSON would leak every user credential if this file + # is read. Auth verification goes through auth_manager; the actual + # mailbox auth is handled by the cell-mail container (Dovecot), + # which has its own credential store. This JSON is metadata only. new_user = { 'username': username, 'domain': domain, 'email': f'{username}@{domain}', - 'password': password, # In production, this should be hashed 'quota_limit': quota_limit, 'quota_used': 0, 'created_at': datetime.utcnow().isoformat(), @@ -313,11 +318,14 @@ class EmailManager(BaseServiceManager): users.append(new_user) self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Create user mailbox directory mailbox_dir = os.path.join(self.email_data_dir, 'mailboxes', f'{username}@{domain}') self.safe_makedirs(mailbox_dir) - + logger.info(f"Created email user: {username}@{domain}") return True except Exception as e: @@ -334,13 +342,16 @@ class EmailManager(BaseServiceManager): if user.get('username') == username and user.get('domain') == domain: del users[i] self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Remove user mailbox directory mailbox_dir = os.path.join(self.email_data_dir, 'mailboxes', f'{username}@{domain}') if os.path.exists(mailbox_dir): import shutil shutil.rmtree(mailbox_dir) - + logger.info(f"Deleted email user: {username}@{domain}") return True @@ -408,11 +419,35 @@ class EmailManager(BaseServiceManager): except Exception as e: return self.handle_error(e, "get_metrics") + def _sync_users_to_cell_config(self): + """Best-effort sync of the email user list into cell_config.json via ConfigManager. + + Only safe metadata (no passwords) is written. Failures are logged as + warnings so they never block the per-service operation that triggered them. + """ + try: + # Import here to avoid circular imports and to tolerate environments + # where config_manager is not on sys.path. + from config_manager import ConfigManager + cm = ConfigManager() + # Build safe user list: strip any sensitive keys that should not + # land in the shared config file. + _SENSITIVE = {'password', 'hashed_password', 'password_hash'} + safe_users = [ + {k: v for k, v in u.items() if k not in _SENSITIVE} + for u in self._load_users() + ] + existing = cm.get_service_config('email') + existing['users'] = safe_users + cm.update_service_config('email', existing) + except Exception as e: + self.logger.warning(f"Failed to sync email users to cell_config.json: {e}") + def restart_service(self) -> bool: - """Restart email service""" + """Restart email service (restarts the cell-mail Docker container).""" try: logger.info('Email service restart requested') - return True + return self._restart_container('cell-mail') except Exception as e: logger.error(f'Failed to restart email service: {e}') return False diff --git a/api/file_manager.py b/api/file_manager.py index 256f1ba..2d555df 100644 --- a/api/file_manager.py +++ b/api/file_manager.py @@ -14,6 +14,7 @@ from datetime import datetime from typing import Dict, List, Optional, Tuple, Any import shutil import hashlib +import bcrypt from base_service_manager import BaseServiceManager logger = logging.getLogger(__name__) @@ -103,9 +104,18 @@ umask = 022 if not username or not password: logger.error("Username and password must not be empty") return False + # Validate username — prevents path traversal in user_dir join below and + # injection of newlines / colons into the htpasswd-format auth file. + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"create_user: invalid username {username!r}") + return False try: - # Create user directory - user_dir = os.path.join(self.files_dir, username) + # Create user directory (containment check) + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"create_user: path traversal for username {username!r}") + return False os.makedirs(user_dir, exist_ok=True) # Create default folders @@ -115,8 +125,12 @@ umask = 022 # Add user to auth file auth_file = os.path.join(self.webdav_dir, 'users') - # Generate password hash - password_hash = hashlib.sha256(password.encode()).hexdigest() + # Generate bcrypt hash; convert $2b$ -> $2y$ for Apache htpasswd compatibility + # (bytemark/webdav is Apache-based; htpasswd-bcrypt uses $2y$ prefix). + bcrypt_hash = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') + if bcrypt_hash.startswith('$2b$'): + bcrypt_hash = '$2y$' + bcrypt_hash[4:] + password_hash = bcrypt_hash with open(auth_file, 'a') as f: f.write(f"{username}:{password_hash}\n") @@ -133,6 +147,10 @@ umask = 022 if not username: logger.error("Username must not be empty") return False + # Validate username before any auth-file rewrite or filesystem ops + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"delete_user: invalid username {username!r}") + return False try: # Remove from auth file auth_file = os.path.join(self.webdav_dir, 'users') @@ -145,8 +163,13 @@ umask = 022 if not line.startswith(f"{username}:"): f.write(line) - # Remove user directory - user_dir = os.path.join(self.files_dir, username) + # Remove user directory — containment check prevents + # username='..' or 'foo/../../etc' from escaping files_dir. + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"delete_user: path traversal for username {username!r}") + return False if os.path.exists(user_dir): shutil.rmtree(user_dir) @@ -460,8 +483,15 @@ umask = 022 if not username or not backup_path: logger.error("Username and backup_path must not be empty") return False + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"backup_user_files: invalid username {username!r}") + return False try: - user_dir = os.path.join(self.files_dir, username) + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"backup_user_files: path traversal for username {username!r}") + return False if os.path.exists(user_dir): shutil.make_archive(backup_path, 'zip', user_dir) @@ -480,8 +510,15 @@ umask = 022 if not username or not backup_path: logger.error("Username and backup_path must not be empty") return False + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"restore_user_files: invalid username {username!r}") + return False try: - user_dir = os.path.join(self.files_dir, username) + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"restore_user_files: path traversal for username {username!r}") + return False # Remove existing user directory if os.path.exists(user_dir): diff --git a/api/firewall_manager.py b/api/firewall_manager.py index 01572c5..bf55c31 100644 --- a/api/firewall_manager.py +++ b/api/firewall_manager.py @@ -114,19 +114,32 @@ def _delete_rule(chain: str, rule_args: List[str]) -> None: # --------------------------------------------------------------------------- def _peer_comment(peer_ip: str) -> str: - return f'pic-peer-{peer_ip.replace(".", "-")}' + # SECURITY: append a non-numeric, non-dash suffix so peer comments cannot + # be substrings of one another. Without this, the comment for 10.0.0.1 + # ('pic-peer-10-0-0-1') is a prefix of 10.0.0.10..19 and a naive + # substring match would delete unrelated peers' rules. + return f'pic-peer-{peer_ip.replace(".", "-")}/32' def clear_peer_rules(peer_ip: str) -> None: """Remove all FORWARD rules tagged with this peer's IP via iptables-save/restore.""" comment = _peer_comment(peer_ip) + # SECURITY: match the comment as a complete --comment token, not a + # substring. iptables-save renders comments as `--comment ""` (or + # occasionally without quotes), so we anchor on the surrounding quotes / + # whitespace. Even with the unique /32 suffix in _peer_comment, we keep + # exact-token matching so a future change to the comment format cannot + # silently re-introduce the substring-deletion bug. + comment_re = re.compile( + rf'--comment\s+["\']?{re.escape(comment)}["\']?(\s|$)' + ) try: # Dump rules, strip matching lines, restore — atomic and order-stable save = _wg_exec(['iptables-save']) if save.returncode != 0: return lines = save.stdout.splitlines() - filtered = [l for l in lines if comment not in l] + filtered = [l for l in lines if not comment_re.search(l)] if len(filtered) == len(lines): return # nothing to remove restore_input = '\n'.join(filtered) + '\n' @@ -243,11 +256,15 @@ def _build_acl_block(blocked_peers_by_service: Dict[str, List[str]], def generate_corefile(peers: List[Dict[str, Any]], corefile_path: str = COREFILE_PATH, - domain: str = 'cell') -> bool: + domain: str = 'cell', + cell_links: Optional[List[Dict[str, Any]]] = None) -> bool: """ Rewrite the CoreDNS Corefile with per-peer ACL rules and reload plugin. The file is written to corefile_path (API-side path mapped into CoreDNS container). domain: the configured cell domain (e.g. 'cell', 'dev') — must match zone file names. + cell_links: optional list of cell-to-cell DNS forwarding entries, each a dict with + 'domain' and 'dns_ip' keys (same shape as CellLinkManager.list_connections()). + When non-empty, a forwarding stanza is appended for each entry. """ try: # Collect which peers block which services @@ -275,8 +292,25 @@ def generate_corefile(peers: List[Dict[str, Any]], corefile_path: str = COREFILE health }} -{primary_zone_block} -""" +{primary_zone_block}""" + + # Append cell-to-cell DNS forwarding stanzas if provided + if cell_links: + for link in cell_links: + link_domain = link.get('domain', '') + link_dns_ip = link.get('dns_ip', '') + if not link_domain or not link_dns_ip: + continue + corefile += ( + f'\n{link_domain} {{\n' + f' forward . {link_dns_ip}\n' + f' cache\n' + f' log\n' + f'}}\n' + ) + else: + corefile += '\n' + # local.{domain} block intentionally omitted: /data/local.zone does not exist # and CoreDNS logs errors on every reload for a missing zone file. os.makedirs(os.path.dirname(corefile_path), exist_ok=True) @@ -309,9 +343,10 @@ def reload_coredns() -> bool: def apply_all_dns_rules(peers: List[Dict[str, Any]], corefile_path: str = COREFILE_PATH, - domain: str = 'cell') -> bool: - """Regenerate Corefile and reload CoreDNS.""" - ok = generate_corefile(peers, corefile_path, domain) + domain: str = 'cell', + cell_links: Optional[List[Dict[str, Any]]] = None) -> bool: + """Regenerate Corefile (including any cell-to-cell forwarding stanzas) and reload CoreDNS.""" + ok = generate_corefile(peers, corefile_path, domain, cell_links) if ok: reload_coredns() return ok diff --git a/api/ip_utils.py b/api/ip_utils.py index 2f98cd9..0b1dacb 100644 --- a/api/ip_utils.py +++ b/api/ip_utils.py @@ -189,6 +189,10 @@ http://api.{domain} {{ reverse_proxy cell-api:3000 }} +http://webui.{domain} {{ + reverse_proxy cell-webui:80 +}} + # Catch-all for direct IP / localhost :80 {{ handle /api/* {{ @@ -200,12 +204,12 @@ http://api.{domain} {{ }} """ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) - tmp = path + '.tmp' - with open(tmp, 'w') as f: + # Write in-place (same inode) so Docker bind-mounted files see the update. + # os.replace() changes the inode which breaks file bind-mounts inside containers. + with open(path, 'w') as f: f.write(content) f.flush() os.fsync(f.fileno()) - os.replace(tmp, path) return True except Exception: return False diff --git a/api/network_manager.py b/api/network_manager.py index a64a76e..5370cb6 100644 --- a/api/network_manager.py +++ b/api/network_manager.py @@ -29,8 +29,28 @@ class NetworkManager(BaseServiceManager): def update_dns_zone(self, zone_name: str, records: List[Dict]) -> bool: """Update DNS zone file with new records""" + # Validate zone_name — must be a safe DNS label, no path traversal + if not isinstance(zone_name, str) or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', zone_name): + logger.error(f"update_dns_zone: invalid zone_name {zone_name!r}") + return False try: zone_file = os.path.join(self.dns_zones_dir, f'{zone_name}.zone') + # Containment check: resolved zone_file must be inside dns_zones_dir + real_dir = os.path.realpath(self.dns_zones_dir) + real_zone = os.path.realpath(zone_file) + if not (real_zone == real_dir or real_zone.startswith(real_dir + os.sep)): + logger.error(f"update_dns_zone: path traversal attempt for zone {zone_name!r}") + return False + # Validate every record's name and value to prevent zone-file injection + for rec in records: + rname = rec.get('name', '') + rvalue = rec.get('value', '') + if rname and not re.match(r'^[a-zA-Z0-9_.*-]{1,253}$', str(rname)): + logger.error(f"update_dns_zone: invalid record name {rname!r}") + return False + if rvalue and not re.match(r'^[a-zA-Z0-9._: -]{1,512}$', str(rvalue)): + logger.error(f"update_dns_zone: invalid record value {rvalue!r}") + return False # Create zone file content content = self._generate_zone_content(zone_name, records) @@ -84,6 +104,16 @@ class NetworkManager(BaseServiceManager): def add_dns_record(self, zone: str, name: str, record_type: str, value: str, ttl: int = 3600) -> bool: """Add a DNS record to a zone""" + # Validate zone, name, and value to prevent injection / path traversal + if not isinstance(zone, str) or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', zone): + logger.error(f"add_dns_record: invalid zone {zone!r}") + return False + if not isinstance(name, str) or not re.match(r'^[a-zA-Z0-9_.*-]{1,253}$', name): + logger.error(f"add_dns_record: invalid name {name!r}") + return False + if not isinstance(value, str) or not re.match(r'^[a-zA-Z0-9._: -]{1,512}$', value): + logger.error(f"add_dns_record: invalid value {value!r}") + return False try: # Load existing records records = self._load_dns_records(zone) @@ -150,13 +180,21 @@ class NetworkManager(BaseServiceManager): return {'restarted': restarted, 'warnings': warnings} def _build_dns_records(self, cell_name: str, ip_range: str) -> List[Dict]: - """Build the standard set of DNS A records for the given subnet.""" + """Build the standard set of DNS A records for the given subnet. + + All user-facing names resolve to the Caddy reverse proxy (caddy IP) so + the Host header is passed through and Caddy routes based on it. + Exception: calendar/files/mail/webdav use dedicated virtual IPs so that + iptables per-service firewall rules can target them by destination IP. + api and webui also go through Caddy — they don't have their own VIPs and + their containers don't serve HTTP on port 80. + """ import ip_utils ips = ip_utils.get_service_ips(ip_range) return [ {'name': cell_name, 'type': 'A', 'value': ips['caddy']}, - {'name': 'api', 'type': 'A', 'value': ips['api']}, - {'name': 'webui', 'type': 'A', 'value': ips['webui']}, + {'name': 'api', 'type': 'A', 'value': ips['caddy']}, + {'name': 'webui', 'type': 'A', 'value': ips['caddy']}, {'name': 'calendar', 'type': 'A', 'value': ips['vip_calendar']}, {'name': 'files', 'type': 'A', 'value': ips['vip_files']}, {'name': 'mail', 'type': 'A', 'value': ips['vip_mail']}, @@ -497,58 +535,75 @@ class NetworkManager(BaseServiceManager): warnings.append(f"cell_name DNS update failed: {e}") return {'restarted': restarted, 'warnings': warnings} + def _load_cell_links(self) -> List[Dict[str, Any]]: + """Load cell_links.json from the data directory (written by CellLinkManager).""" + links_file = os.path.join(self.data_dir, 'cell_links.json') + if os.path.exists(links_file): + try: + with open(links_file) as f: + return json.load(f) + except Exception: + return [] + return [] + def add_cell_dns_forward(self, domain: str, dns_ip: str) -> Dict[str, Any]: - """Append a CoreDNS forwarding block for a remote cell's domain.""" + """Register a CoreDNS forwarding entry for a remote cell's domain. + + Validates inputs, then rebuilds the entire Corefile via + firewall_manager.apply_all_dns_rules() so that no existing stanza is + silently wiped. Does NOT write the Corefile directly. + """ + import ipaddress + import firewall_manager as fm restarted = [] warnings = [] + # Validate dns_ip — newlines/garbage would inject arbitrary CoreDNS directives try: - corefile = os.path.join(self.config_dir, 'dns', 'Corefile') - if not os.path.exists(corefile): - warnings.append('Corefile not found') - return {'restarted': restarted, 'warnings': warnings} - with open(corefile) as f: - content = f.read() - marker = f'# cell:{domain}' - if marker in content: - return {'restarted': restarted, 'warnings': warnings} # already present - forward_block = ( - f'\n{marker}\n' - f'{domain} {{\n' - f' forward . {dns_ip}\n' - f' log\n' - f'}}\n' - ) - with open(corefile, 'a') as f: - f.write(forward_block) - self._reload_dns_service() + ipaddress.ip_address(dns_ip) + except (ValueError, TypeError): + warnings.append(f'add_cell_dns_forward: invalid dns_ip {dns_ip!r}') + return {'restarted': restarted, 'warnings': warnings} + # Validate domain — reject newlines, braces, spaces, and any non-DNS chars + if (not isinstance(domain, str) + or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', domain) + or any(c in domain for c in ('\n', '\r', '{', '}', ' ', '\t'))): + warnings.append(f'add_cell_dns_forward: invalid domain {domain!r}') + return {'restarted': restarted, 'warnings': warnings} + try: + # Build the full forwarding list: existing links + new entry (deduped by domain) + existing_links = self._load_cell_links() + # The new entry may not yet be in cell_links.json (CellLinkManager saves after + # calling us), so we merge it in here. + merged = [l for l in existing_links if l.get('domain') != domain] + merged.append({'domain': domain, 'dns_ip': dns_ip}) + + corefile_path = os.path.join(self.config_dir, 'dns', 'Corefile') + # Peers list is empty here; the full peer list is used by the periodic + # apply_all_dns_rules() call from app.py. We only need to persist the + # forwarding stanza without disturbing whatever peer ACLs are in the file. + fm.apply_all_dns_rules([], corefile_path=corefile_path, cell_links=merged) restarted.append('cell-dns (reloaded)') except Exception as e: warnings.append(f'add_cell_dns_forward failed: {e}') return {'restarted': restarted, 'warnings': warnings} def remove_cell_dns_forward(self, domain: str) -> Dict[str, Any]: - """Remove a CoreDNS forwarding block for a remote cell's domain.""" - import re + """Unregister a CoreDNS forwarding entry for a remote cell's domain. + + Rebuilds the entire Corefile via firewall_manager.apply_all_dns_rules() + with the named domain excluded. Does NOT write the Corefile directly. + """ + import firewall_manager as fm restarted = [] warnings = [] try: - corefile = os.path.join(self.config_dir, 'dns', 'Corefile') - if not os.path.exists(corefile): - return {'restarted': restarted, 'warnings': warnings} - with open(corefile) as f: - content = f.read() - marker = f'# cell:{domain}' - if marker not in content: - return {'restarted': restarted, 'warnings': warnings} - new_content = re.sub( - rf'\n# cell:{re.escape(domain)}\n{re.escape(domain)}\s*\{{[^}}]*\}}\n', - '', - content, - flags=re.DOTALL, - ) - with open(corefile, 'w') as f: - f.write(new_content) - self._reload_dns_service() + existing_links = self._load_cell_links() + # Exclude the domain being removed; CellLinkManager will also remove it + # from cell_links.json after this call returns. + remaining = [l for l in existing_links if l.get('domain') != domain] + + corefile_path = os.path.join(self.config_dir, 'dns', 'Corefile') + fm.apply_all_dns_rules([], corefile_path=corefile_path, cell_links=remaining) restarted.append('cell-dns (reloaded)') except Exception as e: warnings.append(f'remove_cell_dns_forward failed: {e}') diff --git a/api/peer_registry.py b/api/peer_registry.py index 941c484..86a14c2 100644 --- a/api/peer_registry.py +++ b/api/peer_registry.py @@ -1,341 +1,360 @@ -#!/usr/bin/env python3 -""" -Peer Registry for Personal Internet Cell -Handles peer registration and management -""" - -import json -import os -import logging -from threading import RLock -from datetime import datetime -from typing import Dict, List, Any, Optional -from base_service_manager import BaseServiceManager - -logger = logging.getLogger(__name__) - -class PeerRegistry(BaseServiceManager): - """Manages peer registration and management""" - - def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config'): - super().__init__('peer_registry', data_dir, config_dir) - self.lock = RLock() - self.peers = [] - self.peers_file = os.path.join(data_dir, 'peers.json') - self._load_peers() - - def get_status(self) -> Dict[str, Any]: - """Get peer registry status""" - try: - with self.lock: - status = { - 'running': True, - 'status': 'online', - 'peers_count': len(self.peers), - 'active_peers': len([p for p in self.peers if p.get('active', True)]), - 'inactive_peers': len([p for p in self.peers if not p.get('active', True)]), - 'last_updated': datetime.utcnow().isoformat(), - 'timestamp': datetime.utcnow().isoformat() - } - - return status - except Exception as e: - return self.handle_error(e, "get_status") - - def test_connectivity(self) -> Dict[str, Any]: - """Test peer registry connectivity""" - try: - # Test file system access - fs_test = self._test_filesystem_access() - - # Test peer data integrity - integrity_test = self._test_data_integrity() - - # Test peer operations - operations_test = self._test_peer_operations() - - results = { - 'filesystem_access': fs_test, - 'data_integrity': integrity_test, - 'peer_operations': operations_test, - 'success': fs_test.get('success', False) and integrity_test.get('success', False), - 'timestamp': datetime.utcnow().isoformat() - } - - return results - except Exception as e: - return self.handle_error(e, "test_connectivity") - - def _test_filesystem_access(self) -> Dict[str, Any]: - """Test filesystem access for peer data""" - try: - # Test if we can read/write to the peers file - test_peer = { - 'peer': 'test_peer', - 'ip': '192.168.1.100', - 'public_key': 'test_key', - 'active': False, - 'test': True - } - - # Test write - with self.lock: - original_peers = self.peers.copy() - self.peers.append(test_peer) - self._save_peers() - - # Test read - with self.lock: - loaded_peers = self.peers.copy() - # Remove test peer - self.peers = [p for p in self.peers if not p.get('test', False)] - self._save_peers() - - # Restore original state - with self.lock: - self.peers = original_peers - self._save_peers() - - return { - 'success': True, - 'message': 'Filesystem access working', - 'read_write': True - } - except Exception as e: - return { - 'success': False, - 'message': f'Filesystem access failed: {str(e)}', - 'error': str(e) - } - - def _test_data_integrity(self) -> Dict[str, Any]: - """Test peer data integrity""" - try: - with self.lock: - # Check if peers data is valid JSON - peers_copy = self.peers.copy() - - # Validate peer structure - valid_peers = 0 - invalid_peers = 0 - - for peer in peers_copy: - if isinstance(peer, dict) and 'peer' in peer and 'ip' in peer: - valid_peers += 1 - else: - invalid_peers += 1 - - return { - 'success': True, - 'message': 'Data integrity check passed', - 'valid_peers': valid_peers, - 'invalid_peers': invalid_peers, - 'total_peers': len(peers_copy) - } - except Exception as e: - return { - 'success': False, - 'message': f'Data integrity check failed: {str(e)}', - 'error': str(e) - } - - def _test_peer_operations(self) -> Dict[str, Any]: - """Test peer operations""" - try: - # Test adding a peer - test_peer = { - 'peer': 'test_operation_peer', - 'ip': '192.168.1.101', - 'public_key': 'test_operation_key', - 'active': False, - 'test': True - } - - # Test add - add_success = self.add_peer(test_peer) - - # Test get - retrieved_peer = self.get_peer('test_operation_peer') - get_success = retrieved_peer is not None - - # Test update - update_success = self.update_peer_ip('test_operation_peer', '192.168.1.102') - - # Test remove - remove_success = self.remove_peer('test_operation_peer') - - return { - 'success': add_success and get_success and update_success and remove_success, - 'message': 'Peer operations working', - 'add_success': add_success, - 'get_success': get_success, - 'update_success': update_success, - 'remove_success': remove_success - } - except Exception as e: - return { - 'success': False, - 'message': f'Peer operations test failed: {str(e)}', - 'error': str(e) - } - - def _load_peers(self): - """Load peers from file""" - try: - # Ensure directory exists - os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) - - if os.path.exists(self.peers_file): - with open(self.peers_file, 'r') as f: - try: - self.peers = json.load(f) - self.logger.info(f"Loaded {len(self.peers)} peers from file") - except Exception as e: - self.logger.error(f"Error loading peers: {e}") - self.peers = [] - else: - self.peers = [] - self.logger.info("No peers file found, starting with empty registry") - except Exception as e: - self.logger.error(f"Error in _load_peers: {e}") - self.peers = [] - - def _save_peers(self): - """Save peers to file""" - try: - # Ensure directory exists - os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) - - with open(self.peers_file, 'w') as f: - json.dump(self.peers, f, indent=2) - - self.logger.info(f"Saved {len(self.peers)} peers to file") - except Exception as e: - self.logger.error(f"Error saving peers: {e}") - - def list_peers(self) -> List[Dict[str, Any]]: - """List all peers""" - with self.lock: - return list(self.peers) - - def get_peer(self, name: str) -> Optional[Dict[str, Any]]: - """Get a specific peer by name""" - with self.lock: - for peer in self.peers: - if peer.get('peer') == name: - return peer - return None - - def add_peer(self, peer_info: Dict[str, Any]) -> bool: - """Add a new peer""" - try: - with self.lock: - if self.get_peer(peer_info.get('peer')): - self.logger.warning(f"Peer {peer_info.get('peer')} already exists") - return False - - # Add timestamp - peer_info['created_at'] = datetime.utcnow().isoformat() - peer_info['active'] = peer_info.get('active', True) - - self.peers.append(peer_info) - self._save_peers() - - self.logger.info(f"Added peer: {peer_info.get('peer')}") - return True - except Exception as e: - self.logger.error(f"Error adding peer: {e}") - return False - - def remove_peer(self, name: str) -> bool: - """Remove a peer""" - try: - with self.lock: - before = len(self.peers) - self.peers = [p for p in self.peers if p.get('peer') != name] - self._save_peers() - - removed = len(self.peers) < before - if removed: - self.logger.info(f"Removed peer: {name}") - else: - self.logger.warning(f"Peer {name} not found for removal") - - return removed - except Exception as e: - self.logger.error(f"Error removing peer {name}: {e}") - return False - - def update_peer(self, name: str, fields: Dict[str, Any]) -> bool: - """Update arbitrary fields on a peer.""" - try: - with self.lock: - for peer in self.peers: - if peer.get('peer') == name: - peer.update(fields) - peer['updated_at'] = datetime.utcnow().isoformat() - self._save_peers() - self.logger.info(f"Updated peer {name}: {list(fields.keys())}") - return True - self.logger.warning(f"Peer {name} not found for update") - return False - except Exception as e: - self.logger.error(f"Error updating peer {name}: {e}") - return False - - def clear_reinstall_flag(self, name: str) -> bool: - """Clear the config_needs_reinstall flag after user downloads new config.""" - return self.update_peer(name, {'config_needs_reinstall': False}) - - def update_peer_ip(self, name: str, new_ip: str) -> bool: - """Update peer IP address""" - try: - with self.lock: - for peer in self.peers: - if peer.get('peer') == name: - old_ip = peer.get('ip') - peer['ip'] = new_ip - peer['updated_at'] = datetime.utcnow().isoformat() - self._save_peers() - - self.logger.info(f"Updated peer {name} IP from {old_ip} to {new_ip}") - return True - - self.logger.warning(f"Peer {name} not found for IP update") - return False - except Exception as e: - self.logger.error(f"Error updating peer {name} IP: {e}") - return False - - def get_peer_stats(self) -> Dict[str, Any]: - """Get peer registry statistics""" - try: - with self.lock: - active_peers = [p for p in self.peers if p.get('active', True)] - inactive_peers = [p for p in self.peers if not p.get('active', True)] - - # Count peers by IP range - ip_ranges = {} - for peer in self.peers: - ip = peer.get('ip', '') - if ip: - range_key = '.'.join(ip.split('.')[:3]) + '.0/24' - ip_ranges[range_key] = ip_ranges.get(range_key, 0) + 1 - - return { - 'total_peers': len(self.peers), - 'active_peers': len(active_peers), - 'inactive_peers': len(inactive_peers), - 'ip_ranges': ip_ranges, - 'timestamp': datetime.utcnow().isoformat() - } - except Exception as e: - self.logger.error(f"Error getting peer stats: {e}") - return { - 'total_peers': 0, - 'active_peers': 0, - 'inactive_peers': 0, - 'ip_ranges': {}, - 'error': str(e), - 'timestamp': datetime.utcnow().isoformat() +#!/usr/bin/env python3 +""" +Peer Registry for Personal Internet Cell +Handles peer registration and management +""" + +import json +import os +import logging +from threading import RLock +from datetime import datetime +from typing import Dict, List, Any, Optional +from base_service_manager import BaseServiceManager + +logger = logging.getLogger(__name__) + +class PeerRegistry(BaseServiceManager): + """Manages peer registration and management""" + + def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config'): + super().__init__('peer_registry', data_dir, config_dir) + self.lock = RLock() + self.peers = [] + self.peers_file = os.path.join(data_dir, 'peers.json') + self._load_peers() + + def get_status(self) -> Dict[str, Any]: + """Get peer registry status""" + try: + with self.lock: + status = { + 'running': True, + 'status': 'online', + 'peers_count': len(self.peers), + 'active_peers': len([p for p in self.peers if p.get('active', True)]), + 'inactive_peers': len([p for p in self.peers if not p.get('active', True)]), + 'last_updated': datetime.utcnow().isoformat(), + 'timestamp': datetime.utcnow().isoformat() + } + + return status + except Exception as e: + return self.handle_error(e, "get_status") + + def test_connectivity(self) -> Dict[str, Any]: + """Test peer registry connectivity""" + try: + # Test file system access + fs_test = self._test_filesystem_access() + + # Test peer data integrity + integrity_test = self._test_data_integrity() + + # Test peer operations + operations_test = self._test_peer_operations() + + results = { + 'filesystem_access': fs_test, + 'data_integrity': integrity_test, + 'peer_operations': operations_test, + 'success': fs_test.get('success', False) and integrity_test.get('success', False), + 'timestamp': datetime.utcnow().isoformat() + } + + return results + except Exception as e: + return self.handle_error(e, "test_connectivity") + + def _test_filesystem_access(self) -> Dict[str, Any]: + """Test filesystem access for peer data""" + try: + # Test if we can read/write to the peers file + test_peer = { + 'peer': 'test_peer', + 'ip': '192.168.1.100', + 'public_key': 'test_key', + 'active': False, + 'test': True + } + + # Test write + with self.lock: + original_peers = self.peers.copy() + self.peers.append(test_peer) + self._save_peers() + + # Test read + with self.lock: + loaded_peers = self.peers.copy() + # Remove test peer + self.peers = [p for p in self.peers if not p.get('test', False)] + self._save_peers() + + # Restore original state + with self.lock: + self.peers = original_peers + self._save_peers() + + return { + 'success': True, + 'message': 'Filesystem access working', + 'read_write': True + } + except Exception as e: + return { + 'success': False, + 'message': f'Filesystem access failed: {str(e)}', + 'error': str(e) + } + + def _test_data_integrity(self) -> Dict[str, Any]: + """Test peer data integrity""" + try: + with self.lock: + # Check if peers data is valid JSON + peers_copy = self.peers.copy() + + # Validate peer structure + valid_peers = 0 + invalid_peers = 0 + + for peer in peers_copy: + if isinstance(peer, dict) and 'peer' in peer and 'ip' in peer: + valid_peers += 1 + else: + invalid_peers += 1 + + return { + 'success': True, + 'message': 'Data integrity check passed', + 'valid_peers': valid_peers, + 'invalid_peers': invalid_peers, + 'total_peers': len(peers_copy) + } + except Exception as e: + return { + 'success': False, + 'message': f'Data integrity check failed: {str(e)}', + 'error': str(e) + } + + def _test_peer_operations(self) -> Dict[str, Any]: + """Test peer operations""" + try: + # Test adding a peer + test_peer = { + 'peer': 'test_operation_peer', + 'ip': '192.168.1.101', + 'public_key': 'test_operation_key', + 'active': False, + 'test': True + } + + # Test add + add_success = self.add_peer(test_peer) + + # Test get + retrieved_peer = self.get_peer('test_operation_peer') + get_success = retrieved_peer is not None + + # Test update + update_success = self.update_peer_ip('test_operation_peer', '192.168.1.102') + + # Test remove + remove_success = self.remove_peer('test_operation_peer') + + return { + 'success': add_success and get_success and update_success and remove_success, + 'message': 'Peer operations working', + 'add_success': add_success, + 'get_success': get_success, + 'update_success': update_success, + 'remove_success': remove_success + } + except Exception as e: + return { + 'success': False, + 'message': f'Peer operations test failed: {str(e)}', + 'error': str(e) + } + + def _load_peers(self): + """Load peers from file""" + try: + # Ensure directory exists + os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) + + if os.path.exists(self.peers_file): + with open(self.peers_file, 'r') as f: + try: + self.peers = json.load(f) + self.logger.info(f"Loaded {len(self.peers)} peers from file") + except Exception as e: + self.logger.error(f"Error loading peers: {e}") + self.peers = [] + else: + self.peers = [] + self.logger.info("No peers file found, starting with empty registry") + except Exception as e: + self.logger.error(f"Error in _load_peers: {e}") + self.peers = [] + + def _save_peers(self): + """Save peers to file""" + try: + # Ensure directory exists + os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) + + # Write to a temp file with restrictive perms, then atomically replace. + # peers.json contains WireGuard private keys — must never be world-readable. + tmp_path = self.peers_file + '.tmp' + # Open with O_CREAT|O_WRONLY|O_TRUNC and mode 0o600 so the file is + # created with restrictive permissions from the very first byte. + fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + try: + with os.fdopen(fd, 'w') as f: + json.dump(self.peers, f, indent=2) + except Exception: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + # Ensure perms are 0o600 even if umask or prior file affected them. + os.chmod(tmp_path, 0o600) + os.replace(tmp_path, self.peers_file) + # Belt-and-braces: also chmod the destination in case it pre-existed + # with looser perms on a filesystem that preserves the destination's mode. + os.chmod(self.peers_file, 0o600) + + self.logger.info(f"Saved {len(self.peers)} peers to file") + except Exception as e: + self.logger.error(f"Error saving peers: {e}") + + def list_peers(self) -> List[Dict[str, Any]]: + """List all peers""" + with self.lock: + return list(self.peers) + + def get_peer(self, name: str) -> Optional[Dict[str, Any]]: + """Get a specific peer by name""" + with self.lock: + for peer in self.peers: + if peer.get('peer') == name: + return peer + return None + + def add_peer(self, peer_info: Dict[str, Any]) -> bool: + """Add a new peer""" + try: + with self.lock: + if self.get_peer(peer_info.get('peer')): + self.logger.warning(f"Peer {peer_info.get('peer')} already exists") + return False + + # Add timestamp + peer_info['created_at'] = datetime.utcnow().isoformat() + peer_info['active'] = peer_info.get('active', True) + + self.peers.append(peer_info) + self._save_peers() + + self.logger.info(f"Added peer: {peer_info.get('peer')}") + return True + except Exception as e: + self.logger.error(f"Error adding peer: {e}") + return False + + def remove_peer(self, name: str) -> bool: + """Remove a peer""" + try: + with self.lock: + before = len(self.peers) + self.peers = [p for p in self.peers if p.get('peer') != name] + self._save_peers() + + removed = len(self.peers) < before + if removed: + self.logger.info(f"Removed peer: {name}") + else: + self.logger.warning(f"Peer {name} not found for removal") + + return removed + except Exception as e: + self.logger.error(f"Error removing peer {name}: {e}") + return False + + def update_peer(self, name: str, fields: Dict[str, Any]) -> bool: + """Update arbitrary fields on a peer.""" + try: + with self.lock: + for peer in self.peers: + if peer.get('peer') == name: + peer.update(fields) + peer['updated_at'] = datetime.utcnow().isoformat() + self._save_peers() + self.logger.info(f"Updated peer {name}: {list(fields.keys())}") + return True + self.logger.warning(f"Peer {name} not found for update") + return False + except Exception as e: + self.logger.error(f"Error updating peer {name}: {e}") + return False + + def clear_reinstall_flag(self, name: str) -> bool: + """Clear the config_needs_reinstall flag after user downloads new config.""" + return self.update_peer(name, {'config_needs_reinstall': False}) + + def update_peer_ip(self, name: str, new_ip: str) -> bool: + """Update peer IP address""" + try: + with self.lock: + for peer in self.peers: + if peer.get('peer') == name: + old_ip = peer.get('ip') + peer['ip'] = new_ip + peer['updated_at'] = datetime.utcnow().isoformat() + self._save_peers() + + self.logger.info(f"Updated peer {name} IP from {old_ip} to {new_ip}") + return True + + self.logger.warning(f"Peer {name} not found for IP update") + return False + except Exception as e: + self.logger.error(f"Error updating peer {name} IP: {e}") + return False + + def get_peer_stats(self) -> Dict[str, Any]: + """Get peer registry statistics""" + try: + with self.lock: + active_peers = [p for p in self.peers if p.get('active', True)] + inactive_peers = [p for p in self.peers if not p.get('active', True)] + + # Count peers by IP range + ip_ranges = {} + for peer in self.peers: + ip = peer.get('ip', '') + if ip: + range_key = '.'.join(ip.split('.')[:3]) + '.0/24' + ip_ranges[range_key] = ip_ranges.get(range_key, 0) + 1 + + return { + 'total_peers': len(self.peers), + 'active_peers': len(active_peers), + 'inactive_peers': len(inactive_peers), + 'ip_ranges': ip_ranges, + 'timestamp': datetime.utcnow().isoformat() + } + except Exception as e: + self.logger.error(f"Error getting peer stats: {e}") + return { + 'total_peers': 0, + 'active_peers': 0, + 'inactive_peers': 0, + 'ip_ranges': {}, + 'error': str(e), + 'timestamp': datetime.utcnow().isoformat() } \ No newline at end of file diff --git a/api/routing_manager.py b/api/routing_manager.py index 024c151..fde7e44 100644 --- a/api/routing_manager.py +++ b/api/routing_manager.py @@ -224,6 +224,22 @@ class RoutingManager(BaseServiceManager): def add_exit_node(self, peer_name: str, peer_ip: str, allowed_domains: List[str] = None) -> bool: """Add exit node configuration""" + # Validation — peer_ip flows into `ip route add default via `; argv + # injection / shell-meta in name would reach iptables/ip via _apply_exit_node. + if not isinstance(peer_name, str) or not re.match(r'^[a-zA-Z0-9_.-]{1,64}$', peer_name): + logger.error(f"add_exit_node: invalid peer_name {peer_name!r}") + return {'success': False, 'error': f'invalid input: peer_name {peer_name!r}'} + try: + ipaddress.ip_address(peer_ip) + except (ValueError, TypeError): + logger.error(f"add_exit_node: invalid peer_ip {peer_ip!r}") + return {'success': False, 'error': f'invalid input: peer_ip {peer_ip!r}'} + if allowed_domains is not None: + if not isinstance(allowed_domains, list): + return {'success': False, 'error': 'invalid input: allowed_domains must be a list'} + for d in allowed_domains: + if not isinstance(d, str) or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', d): + return {'success': False, 'error': f'invalid input: domain {d!r}'} try: rules = self._load_rules() @@ -251,6 +267,23 @@ class RoutingManager(BaseServiceManager): def add_bridge_route(self, source_peer: str, target_peer: str, allowed_networks: List[str]) -> bool: """Add bridge route between peers""" + # source_peer is a name label; target_peer flows into iptables `-d` so must be + # an IP/CIDR. allowed_networks flows into iptables `-s` so must all be CIDRs. + if not isinstance(source_peer, str) or not re.match(r'^[a-zA-Z0-9_.-]{1,64}$', source_peer): + logger.error(f"add_bridge_route: invalid source_peer {source_peer!r}") + return {'success': False, 'error': f'invalid input: source_peer {source_peer!r}'} + try: + ipaddress.ip_network(target_peer, strict=False) + except (ValueError, TypeError): + logger.error(f"add_bridge_route: invalid target_peer {target_peer!r}") + return {'success': False, 'error': f'invalid input: target_peer must be IP/CIDR, got {target_peer!r}'} + if not isinstance(allowed_networks, list) or not allowed_networks: + return {'success': False, 'error': 'invalid input: allowed_networks must be a non-empty list'} + for n in allowed_networks: + try: + ipaddress.ip_network(n, strict=False) + except (ValueError, TypeError): + return {'success': False, 'error': f'invalid input: network {n!r}'} try: rules = self._load_rules() @@ -279,6 +312,22 @@ class RoutingManager(BaseServiceManager): def add_split_route(self, network: str, exit_peer: str, fallback_peer: str = None) -> bool: """Add split routing rule""" + # network flows into `ip route add `; exit_peer flows into `via `. + try: + ipaddress.ip_network(network, strict=False) + except (ValueError, TypeError): + logger.error(f"add_split_route: invalid network {network!r}") + return {'success': False, 'error': f'invalid input: network {network!r}'} + try: + ipaddress.ip_address(exit_peer) + except (ValueError, TypeError): + logger.error(f"add_split_route: invalid exit_peer {exit_peer!r}") + return {'success': False, 'error': f'invalid input: exit_peer must be an IP, got {exit_peer!r}'} + if fallback_peer is not None: + try: + ipaddress.ip_address(fallback_peer) + except (ValueError, TypeError): + return {'success': False, 'error': f'invalid input: fallback_peer must be an IP, got {fallback_peer!r}'} try: rules = self._load_rules() diff --git a/api/vault_manager.py b/api/vault_manager.py index 104b94b..58621a0 100644 --- a/api/vault_manager.py +++ b/api/vault_manager.py @@ -162,10 +162,26 @@ class VaultManager(BaseServiceManager): if self.fernet_key_file.exists(): with open(self.fernet_key_file, "rb") as f: self.fernet_key = f.read() + # SECURITY: ensure key file is owner-only readable on every load + # in case it was created with looser perms by an older version. + try: + os.chmod(str(self.fernet_key_file), 0o600) + except OSError: + pass else: self.fernet_key = Fernet.generate_key() - with open(self.fernet_key_file, "wb") as f: + # SECURITY: create the key file with 0o600 from the first byte + # so the secret is never world-readable, even momentarily. + fd = os.open( + str(self.fernet_key_file), + os.O_WRONLY | os.O_CREAT | os.O_TRUNC, + 0o600, + ) + with os.fdopen(fd, "wb") as f: f.write(self.fernet_key) + # Belt-and-braces chmod in case umask or a pre-existing file + # left wider permissions in place. + os.chmod(str(self.fernet_key_file), 0o600) self.fernet = Fernet(self.fernet_key) except (PermissionError, OSError): self.fernet_key = Fernet.generate_key() diff --git a/api/wireguard_manager.py b/api/wireguard_manager.py index abc5d12..473c6aa 100644 --- a/api/wireguard_manager.py +++ b/api/wireguard_manager.py @@ -206,6 +206,62 @@ class WireGuardManager(BaseServiceManager): """Return split-tunnel AllowedIPs: VPN subnet + Docker bridge.""" return f'{self._get_configured_network()}, 172.20.0.0/16' + def _load_registered_peers(self) -> list: + """Read active peers from peers.json for wg0.conf reconstruction after bootstrap.""" + import json as _json + peers_file = os.path.join(self.data_dir, 'peers.json') + try: + with open(peers_file) as f: + peers = _json.load(f) + return [ + p for p in peers + if isinstance(p, dict) + and p.get('active', True) + and p.get('public_key') + and p.get('ip') + ] + except Exception: + return [] + + def _sync_keys_from_conf(self) -> None: + """Sync the API's key store from wg0.conf so both agree on the server identity. + + linuxserver/wireguard auto-generates a PrivateKey on first container start. + The API generates its own key independently. Any time apply_config() runs, + read the PrivateKey from wg0.conf (the container's authoritative source) and + update the API's key-store files to match — keeping get_keys() consistent. + """ + import base64 as _b64 + cf = self._config_file() + if not os.path.exists(cf): + return + try: + with open(cf) as f: + raw = f.read() + for line in raw.splitlines(): + stripped = line.strip() + if stripped.startswith('PrivateKey'): + conf_priv = stripped.split('=', 1)[1].strip() + api_keys = self.get_keys() + if conf_priv == api_keys.get('private_key'): + return # already in sync + # Derive public key from private key and update both files + from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey + priv_bytes = _b64.b64decode(conf_priv) + priv_obj = X25519PrivateKey.from_private_bytes(priv_bytes) + pub_bytes = priv_obj.public_key().public_bytes_raw() + pub_b64 = _b64.b64encode(pub_bytes).decode() + priv_file = os.path.join(self.keys_dir, 'private.key') + pub_file = os.path.join(self.keys_dir, 'public.key') + with open(priv_file, 'wb') as f: + f.write(priv_bytes) + with open(pub_file, 'wb') as f: + f.write(pub_bytes) + logger.info(f'wg: key-store synced from wg0.conf (new pub={pub_b64[:16]}...)') + return + except Exception as e: + logger.warning(f'_sync_keys_from_conf failed (non-fatal): {e}') + def apply_config(self, config: Dict[str, Any]) -> Dict[str, Any]: """Update wg0.conf interface fields and restart cell-wireguard.""" restarted = [] @@ -215,12 +271,26 @@ class WireGuardManager(BaseServiceManager): warnings.append('wg0.conf not found — skipping') return {'restarted': restarted, 'warnings': warnings} try: + # Sync the API key-store from wg0.conf before doing anything else. + # linuxserver/wireguard auto-generates its own key; this keeps both in sync + # so get_peer_config() always embeds the correct server public key. + self._sync_keys_from_conf() + with open(cf) as f: raw = f.read() # Bootstrap from generate_config() if file is empty or has no [Interface] if not raw.strip() or '[Interface]' not in raw: raw = self.generate_config() + # Restore all registered peers so clients can reconnect immediately + for peer in self._load_registered_peers(): + raw += ( + f'\n[Peer]\n' + f'# {peer.get("peer", "unknown")}\n' + f'PublicKey = {peer["public_key"]}\n' + f'AllowedIPs = {peer["ip"]}/32\n' + f'PersistentKeepalive = 25\n' + ) with open(cf, 'w') as f: f.write(raw) warnings.append('wg0.conf was empty — regenerated from keys') @@ -389,12 +459,38 @@ class WireGuardManager(BaseServiceManager): Unlike add_peer(), allows a subnet CIDR as AllowedIPs (whole remote VPN range). The endpoint is expected to already include the port (e.g. '1.2.3.4:51820'). """ - import ipaddress + import ipaddress, re as _re + # Validate public_key strictly — empty/garbled keys later cause remove_peer("") + # to wipe ALL peer blocks via substring match. + if not isinstance(public_key, str) or not _re.match(r'^[A-Za-z0-9+/]{43}=$', public_key.strip()): + logger.error(f'add_cell_peer: invalid public_key') + return False + # Validate name — reject newlines/brackets that could inject config blocks + if not isinstance(name, str) or not _re.match(r'^[A-Za-z0-9_. -]{1,64}$', name): + logger.error(f'add_cell_peer: invalid name {name!r}') + return False + # Validate endpoint as host:port — reject newlines and out-of-range ports + if endpoint: + if not isinstance(endpoint, str) or not _re.match(r'^[A-Za-z0-9._-]+:\d{1,5}$', endpoint): + logger.error(f'add_cell_peer: invalid endpoint {endpoint!r}') + return False + try: + _port = int(endpoint.rsplit(':', 1)[1]) + if not (1 <= _port <= 65535): + logger.error(f'add_cell_peer: endpoint port out of range: {endpoint!r}') + return False + except (ValueError, IndexError): + logger.error(f'add_cell_peer: invalid endpoint port: {endpoint!r}') + return False try: ipaddress.ip_network(vpn_subnet, strict=False) except ValueError as e: logger.error(f'add_cell_peer: invalid vpn_subnet {vpn_subnet!r}: {e}') return False + # Reject any whitespace/newlines in vpn_subnet that ip_network() may have tolerated + if any(c.isspace() for c in vpn_subnet): + logger.error(f'add_cell_peer: vpn_subnet contains whitespace: {vpn_subnet!r}') + return False try: content = self._read_config() peer_block = ( @@ -461,6 +557,16 @@ class WireGuardManager(BaseServiceManager): def update_peer_ip(self, public_key: str, new_ip: str) -> bool: """Update AllowedIPs for the peer with the given public key.""" + import ipaddress + # Reject whitespace/newlines that ip_network() may tolerate but would inject config + if not isinstance(new_ip, str) or any(c.isspace() for c in new_ip): + logger.error(f'update_peer_ip: invalid new_ip {new_ip!r}') + return False + try: + ipaddress.ip_network(new_ip, strict=False) + except ValueError as e: + logger.error(f'update_peer_ip: invalid new_ip {new_ip!r}: {e}') + return False content = self._read_config() if f'PublicKey = {public_key}' not in content: return False @@ -667,6 +773,25 @@ class WireGuardManager(BaseServiceManager): status = self.get_status() running = status.get('running', False) return {'success': running, 'reachable': running, 'status': status.get('status')} + # Validate target_ip — reject argv injection (any string starting with '-' would + # be parsed by ping as a flag) and any non-IP input. + import ipaddress + if not isinstance(peer_ip, str) or peer_ip.startswith('-'): + return { + 'peer_ip': peer_ip, + 'ping_success': False, + 'ping_output': '', + 'ping_error': 'invalid peer_ip', + } + try: + ipaddress.ip_address(peer_ip) + except ValueError: + return { + 'peer_ip': peer_ip, + 'ping_success': False, + 'ping_output': '', + 'ping_error': 'invalid peer_ip', + } try: result = subprocess.run( ['ping', '-c', '1', '-W', '2', peer_ip], diff --git a/config/api/.gitkeep b/config/api/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/api/caddy/Caddyfile b/config/api/caddy/Caddyfile index b5fe71c..98b1b29 100644 --- a/config/api/caddy/Caddyfile +++ b/config/api/caddy/Caddyfile @@ -1,92 +1,57 @@ -# Personal Internet Cell - Caddy Configuration -# This serves as the main reverse proxy and TLS termination point - -# Global settings { - # Auto-generate certificates for .cell domains - auto_https disable_redirects + auto_https off } -# Main cell domain - replace 'mycell' with your cell name -mycell.cell { - # TLS with internal CA - tls internal - - # API endpoints +# Main cell domain — no service-IP restriction needed +http://pic0.lan, http://172.20.0.2:80 { handle /api/* { reverse_proxy cell-api:3000 } - - # Web UI - handle / { - reverse_proxy cell-webui:80 - } - - # Email web interface - handle /mail { - reverse_proxy cell-mail:80 - } - - # Calendar and contacts - handle /calendar { + handle /calendar* { reverse_proxy cell-radicale:5232 } - - # File storage - handle /files { - reverse_proxy cell-webdav:80 - } - - # DNS management interface - handle /dns { - reverse_proxy cell-dns:8080 - } - - # RainLoop Webmail - handle_path /webmail/* { - reverse_proxy cell-rainloop:8888 - } - - # FileGator File Browser - handle /files-ui* { + handle /files* { reverse_proxy cell-filegator:8080 } + handle /webmail* { + reverse_proxy cell-rainloop:8888 + } + handle { + reverse_proxy cell-webui:80 + } } -# Peer cell domains (will be dynamically added) -# Example: bob.cell { -# reverse_proxy cell-wireguard:51820 -# } +# Per-service virtual IPs — each gets its own IP so iptables can target them +http://calendar.lan, http://172.20.0.21:80 { + reverse_proxy cell-radicale:5232 +} -# Local development -localhost { - # API endpoints +http://files.lan, http://172.20.0.22:80 { + reverse_proxy cell-filegator:8080 +} + +http://mail.lan, http://webmail.lan, http://172.20.0.23:80 { + reverse_proxy cell-rainloop:8888 +} + +http://webdav.lan, http://172.20.0.24:80 { + reverse_proxy cell-webdav:80 +} + +http://api.lan { + reverse_proxy cell-api:3000 +} + +http://webui.lan { + reverse_proxy cell-webui:80 +} + +# Catch-all for direct IP / localhost +:80 { handle /api/* { reverse_proxy cell-api:3000 } - - # Web UI - handle / { + handle { reverse_proxy cell-webui:80 } - - # Email web interface - handle /mail { - reverse_proxy cell-mail:80 - } - - # Calendar and contacts - handle /calendar { - reverse_proxy cell-radicale:5232 - } - - # File storage - handle /files { - reverse_proxy cell-webdav:80 - } - - # DNS management interface - handle /dns { - reverse_proxy cell-dns:8080 - } -} \ No newline at end of file +} diff --git a/config/api/calendar.json b/config/api/calendar.json new file mode 100644 index 0000000..81fbc2e --- /dev/null +++ b/config/api/calendar.json @@ -0,0 +1,3 @@ +{ + "port": 5233 +} \ No newline at end of file diff --git a/config/api/cell_config.json b/config/api/cell_config.json new file mode 100644 index 0000000..1a4a6a6 --- /dev/null +++ b/config/api/cell_config.json @@ -0,0 +1,22 @@ +{ + "_identity": { + "cell_name": "pic0", + "domain": "dec", + "ip_range": "172.20.0.0/16", + "wireguard_port": 51820 + }, + "_pending_restart": { + "needs_restart": false, + "changes": [], + "containers": [], + "network_recreate": false + }, + "calendar": { + "port": 5233 + }, + "wireguard": { + "port": 51820, + "address": "", + "private_key": "" + } +} \ No newline at end of file diff --git a/config/caddy/Caddyfile b/config/caddy/Caddyfile index f385c8a..22cbd0b 100644 --- a/config/caddy/Caddyfile +++ b/config/caddy/Caddyfile @@ -3,7 +3,7 @@ } # Main cell domain — no service-IP restriction needed -http://mycell.cell, http://172.20.0.2:80 { +http://pic0.dec, http://172.20.0.2:80 { handle /api/* { reverse_proxy cell-api:3000 } @@ -22,26 +22,30 @@ http://mycell.cell, http://172.20.0.2:80 { } # Per-service virtual IPs — each gets its own IP so iptables can target them -http://calendar.cell, http://172.20.0.21:80 { +http://calendar.dec, http://172.20.0.21:80 { reverse_proxy cell-radicale:5232 } -http://files.cell, http://172.20.0.22:80 { +http://files.dec, http://172.20.0.22:80 { reverse_proxy cell-filegator:8080 } -http://mail.cell, http://webmail.cell, http://172.20.0.23:80 { +http://mail.dec, http://webmail.dec, http://172.20.0.23:80 { reverse_proxy cell-rainloop:8888 } -http://webdav.cell, http://172.20.0.24:80 { +http://webdav.dec, http://172.20.0.24:80 { reverse_proxy cell-webdav:80 } -http://api.cell { +http://api.dec { reverse_proxy cell-api:3000 } +http://webui.dec { + reverse_proxy cell-webui:80 +} + # Catch-all for direct IP / localhost :80 { handle /api/* { diff --git a/config/dhcp/.gitkeep b/config/dhcp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/dns/.gitkeep b/config/dns/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/dns/Corefile b/config/dns/Corefile index ad1f4c2..74bf426 100644 --- a/config/dns/Corefile +++ b/config/dns/Corefile @@ -5,12 +5,8 @@ health } -dev { - file /data/dev.zone +dec { + file /data/dec.zone log } -local.dev { - file /data/local.zone - log -} diff --git a/config/ntp/.gitkeep b/config/ntp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/webdav/.gitkeep b/config/webdav/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml index fe1aa15..5706bb1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -199,6 +199,7 @@ services: - ./data/api:/app/data - ./data/dns:/app/data/dns - ./config/api:/app/config + - ./config/caddy:/app/config-caddy - ./config/wireguard:/app/config/wireguard - ./config/dns:/app/config/dns - ./data/logs:/app/api/data/logs diff --git a/tests/e2e/api/test_peer_endpoints.py b/tests/e2e/api/test_peer_endpoints.py index aea2ec5..33fbc56 100644 --- a/tests/e2e/api/test_peer_endpoints.py +++ b/tests/e2e/api/test_peer_endpoints.py @@ -4,8 +4,8 @@ Scenarios 20, 21: Peer role access scoping. Tests cover: - Peer is blocked from admin-only routes (config, wireguard, peer list) - Peer can access /api/peer/dashboard and /api/peer/services - - Dashboard response shape (peer_name, online, rx_bytes, tx_bytes, allowed_ips) - - Services response shape (wireguard, email, caldav, webdav sections) + - Dashboard response shape (name, online, transfer_rx, transfer_tx, service_urls) + - Services response shape (wireguard, email, caldav, files sections) - Peer can change their own password and use the new credential - Peer cannot call admin/reset-password """ @@ -54,12 +54,24 @@ def test_peer_dashboard_has_expected_fields(peer_client): r = peer_client.get('/api/peer/dashboard') assert r.status_code == 200 data = r.json() - missing = [f for f in ('peer_name', 'online', 'rx_bytes', 'tx_bytes', 'allowed_ips') if f not in data] + missing = [f for f in ('name', 'online', 'transfer_rx', 'transfer_tx', 'allowed_ips', 'service_urls') if f not in data] assert not missing, ( f"Dashboard response missing fields {missing}. Got keys: {list(data.keys())}" ) +def test_peer_dashboard_no_stale_field_names(peer_client): + """Verify renamed fields are gone — old names cause silent UI blanks.""" + r = peer_client.get('/api/peer/dashboard') + assert r.status_code == 200 + data = r.json() + stale = [f for f in ('peer_name', 'rx_bytes', 'tx_bytes') if f in data] + assert not stale, ( + f"Dashboard response still has stale fields {stale} — " + "PeerDashboard.jsx reads name/transfer_rx/transfer_tx" + ) + + def test_peer_can_access_own_services(peer_client): r = peer_client.get('/api/peer/services') assert r.status_code == 200, ( @@ -71,12 +83,61 @@ def test_peer_services_has_expected_sections(peer_client): r = peer_client.get('/api/peer/services') assert r.status_code == 200 data = r.json() - missing = [k for k in ('wireguard', 'email', 'caldav', 'webdav') if k not in data] + missing = [k for k in ('wireguard', 'email', 'caldav', 'files') if k not in data] assert not missing, ( f"Services response missing sections {missing}. Got keys: {list(data.keys())}" ) +def test_peer_services_no_stale_keys(peer_client): + """Verify renamed keys are gone — old names cause silent UI blanks.""" + r = peer_client.get('/api/peer/services') + assert r.status_code == 200 + data = r.json() + assert 'webdav' not in data, ( + "'webdav' still present at top level — MyServices.jsx reads 'files'" + ) + + +def test_peer_services_email_structure(peer_client): + """Email section must use nested smtp/imap objects and email.address.""" + r = peer_client.get('/api/peer/services') + assert r.status_code == 200 + email = r.json().get('email', {}) + assert 'address' in email, f"email.address missing; email keys: {list(email)}" + assert 'smtp' in email and isinstance(email['smtp'], dict), \ + f"email.smtp must be a dict; got: {email.get('smtp')}" + assert 'imap' in email and isinstance(email['imap'], dict), \ + f"email.imap must be a dict; got: {email.get('imap')}" + assert 'host' in email['smtp'], "email.smtp.host missing" + assert 'host' in email['imap'], "email.imap.host missing" + assert 'imap_host' not in email, "'imap_host' still flat — should be email.imap.host" + assert 'smtp_host' not in email, "'smtp_host' still flat — should be email.smtp.host" + + +def test_peer_services_caldav_url_uses_calendar_domain(peer_client): + """CalDAV URL must be calendar.dev, not radicale.dev:5232.""" + r = peer_client.get('/api/peer/services') + assert r.status_code == 200 + url = r.json().get('caldav', {}).get('url', '') + assert 'radicale' not in url, \ + f"CalDAV URL must not contain 'radicale' — no radicale.dev DNS record; got: {url}" + assert ':5232' not in url, \ + f"CalDAV URL exposes port 5232 — use Caddy-proxied URL; got: {url}" + + +def test_peer_services_wireguard_dns_not_vpn_gateway(peer_client): + """WireGuard DNS must be the CoreDNS IP, not the VPN gateway 10.0.0.1.""" + r = peer_client.get('/api/peer/services') + assert r.status_code == 200 + dns = r.json().get('wireguard', {}).get('dns', '') + assert dns != '10.0.0.1', ( + "wireguard.dns is 10.0.0.1 (WireGuard VPN gateway) — " + "DNS queries to 10.0.0.1 fail because the VPN server doesn't run a DNS resolver; " + "must be the CoreDNS container IP" + ) + + # --------------------------------------------------------------------------- # Auth management — scoping # --------------------------------------------------------------------------- diff --git a/tests/e2e/ui/test_peer_dashboard.py b/tests/e2e/ui/test_peer_dashboard.py index 58a8d18..7d15ff7 100644 --- a/tests/e2e/ui/test_peer_dashboard.py +++ b/tests/e2e/ui/test_peer_dashboard.py @@ -3,16 +3,22 @@ Peer dashboard and My Services page tests. Scenarios: 12. Peer sees their own dashboard (PeerDashboard.jsx renders peer.name as

) - 13. Peer's My Services page loads and shows the WireGuard VPN section + 13. Peer's My Services page loads and shows all service sections + 14. Peer dashboard shows service icon links (calendar, files, mail, webdav) + 15. My Services shows correct CalDAV URL (calendar.dev not radicale.dev:5232) + 16. My Services shows email address field (not username) Key selectors from PeerDashboard.jsx: - - h1 shows peer.name (line 61: `{peer.name || 'My Dashboard'}`) - - "VPN Address" stat card label (line 76) - - "Quick Access" → "My Services" link (line 117-119) + - h1 shows peer.name (peer.name from /api/peer/dashboard) + - "VPN Address" stat card label + - "Quick Access" section with service icon links from service_urls + - "My Services" link Key selectors from MyServices.jsx: - - h2 "WireGuard VPN" (line 93) + - h2 "WireGuard VPN" - h2 "Email", h2 "Calendar & Contacts", h2 "Files" + - "Address" label for email (not "Username") + - "CalDAV URL" label with calendar.dev value """ import pytest @@ -131,3 +137,78 @@ def test_peer_my_services_shows_files_section(peer_page, webui_base): pytest.xfail( "Files section heading not found on /my-services" ) + + +# ── 14. Service icon links ──────────────────────────────────────────────────── + +def test_peer_dashboard_has_calendar_link(peer_page, webui_base): + """PeerDashboard Quick Access section renders a Calendar icon link.""" + page, _ = peer_page + page.wait_for_load_state('networkidle') + try: + page.wait_for_selector('a:has-text("Calendar")', timeout=5000) + except Exception: + pytest.xfail( + "Calendar link not found on peer dashboard Quick Access — " + "check that service_urls.calendar is populated and PeerDashboard.jsx renders it" + ) + + +def test_peer_dashboard_has_files_link(peer_page, webui_base): + """PeerDashboard Quick Access section renders a Files icon link.""" + page, _ = peer_page + page.wait_for_load_state('networkidle') + try: + page.wait_for_selector('a:has-text("Files")', timeout=5000) + except Exception: + pytest.xfail( + "Files link not found on peer dashboard Quick Access" + ) + + +def test_peer_dashboard_has_mail_link(peer_page, webui_base): + """PeerDashboard Quick Access section renders a Mail icon link.""" + page, _ = peer_page + page.wait_for_load_state('networkidle') + try: + page.wait_for_selector('a:has-text("Mail")', timeout=5000) + except Exception: + pytest.xfail( + "Mail link not found on peer dashboard Quick Access" + ) + + +# ── 15. CalDAV URL correctness ──────────────────────────────────────────────── + +def test_peer_my_services_caldav_url_no_radicale(peer_page, webui_base): + """CalDAV URL shown in My Services must not contain 'radicale' (no DNS record).""" + page, _ = peer_page + page.goto(f"{webui_base}/my-services") + page.wait_for_load_state('networkidle') + try: + # If radicale.dev appears as CalDAV URL it means the bug is back + radicale_url = page.query_selector('text=radicale') + assert radicale_url is None, ( + "Found 'radicale' text on My Services page — " + "CalDAV URL should be calendar.dev, not radicale.dev:5232" + ) + except AssertionError: + raise + except Exception: + pass # page didn't load — other tests cover that + + +# ── 16. Email address display ───────────────────────────────────────────────── + +def test_peer_my_services_shows_address_label(peer_page, webui_base): + """MyServices.jsx renders 'Address' label for email (reads email.address).""" + page, _ = peer_page + page.goto(f"{webui_base}/my-services") + page.wait_for_load_state('networkidle') + try: + page.wait_for_selector('text=Address', timeout=5000) + except Exception: + pytest.xfail( + "'Address' label not found on My Services email section — " + "check that email.address is populated in /api/peer/services" + ) diff --git a/tests/e2e/wg/conftest.py b/tests/e2e/wg/conftest.py index 359e01d..ac1f80b 100644 --- a/tests/e2e/wg/conftest.py +++ b/tests/e2e/wg/conftest.py @@ -1,10 +1,16 @@ import os +import shutil import pytest import tempfile import secrets from helpers.wg_runner import WGInterface, build_wg_config, cleanup_stale_e2e_interfaces +def pytest_configure(config): + if not shutil.which('wg-quick'): + pytest.skip('wg-quick not found — skipping WireGuard E2E tests', allow_module_level=True) + + @pytest.fixture(scope='session', autouse=True) def cleanup_stale_wg_interfaces(): cleanup_stale_e2e_interfaces() diff --git a/tests/e2e/wg/test_caddy_routing.py b/tests/e2e/wg/test_caddy_routing.py new file mode 100644 index 0000000..e4128ba --- /dev/null +++ b/tests/e2e/wg/test_caddy_routing.py @@ -0,0 +1,275 @@ +""" +WireGuard E2E: Caddy per-domain routing correctness. + +Scenarios covered: + 35. api. proxies to the API (returns JSON), not the WebUI + 36. calendar. via VIP proxies to Radicale, not the WebUI + 37. files. via VIP proxies to Filegator, not the WebUI + 38. mail. via VIP proxies to Rainloop, not the WebUI + 39. webdav. via VIP proxies to the WebDAV service, not the WebUI + 40. Direct VIP requests (by IP) go to the correct service + 41. Catch-all :80 serves WebUI for unknown hosts but routes /api/* to API + +The WebUI serves a React app — its HTML starts with ''. +Any service domain that returns that string is incorrectly falling through +to the catch-all :80 block instead of being routed by its Host header. + +These tests require a live PIC stack with WireGuard and are marked `wg`. +They run via `make test-e2e-wg` or `pytest tests/e2e/wg/ -m wg`. +""" + +import subprocess +import pytest + +pytestmark = pytest.mark.wg + +_WEBUI_MARKER = '' + + +def _config(admin_client) -> dict: + r = admin_client.get('/api/config') + return r.json() if r.status_code == 200 else {} + + +def _domain(admin_client) -> str: + return _config(admin_client).get('domain') or 'lan' + + +def _dns_ip(admin_client) -> str: + cfg = _config(admin_client) + return cfg.get('service_ips', {}).get('dns') or '172.20.0.3' + + +def _curl_host(ip: str, host: str, path: str = '/', timeout: int = 8) -> tuple[int, str]: + """ + Make an HTTP request to `ip` with the given Host header. + Returns (http_code, body_snippet). + """ + result = subprocess.run( + ['curl', '-s', '--connect-timeout', '5', + '-H', f'Host: {host}', + '-w', '\n__HTTP_CODE__:%{http_code}', + f'http://{ip}{path}'], + capture_output=True, text=True, timeout=timeout, + ) + output = result.stdout + body = '' + code = 0 + if '__HTTP_CODE__:' in output: + parts = output.rsplit('__HTTP_CODE__:', 1) + body = parts[0].lower() + try: + code = int(parts[1].strip()) + except ValueError: + pass + return code, body + + +def _curl_domain(host: str, path: str = '/', dns_ip: str = '', timeout: int = 8) -> tuple[int, str]: + """Make an HTTP request using curl's --dns-servers to resolve via CoreDNS.""" + cmd = ['curl', '-s', '--connect-timeout', '5', + '-w', '\n__HTTP_CODE__:%{http_code}', + f'http://{host}{path}'] + if dns_ip: + cmd = ['curl', '-s', '--connect-timeout', '5', + '--dns-servers', dns_ip, + '-w', '\n__HTTP_CODE__:%{http_code}', + f'http://{host}{path}'] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + output = result.stdout + body = '' + code = 0 + if '__HTTP_CODE__:' in output: + parts = output.rsplit('__HTTP_CODE__:', 1) + body = parts[0].lower() + try: + code = int(parts[1].strip()) + except ValueError: + pass + return code, body + + +# ── Scenario 35: api. routes to API ─────────────────────────────────── + +def test_api_domain_returns_json_not_webui(connected_peer, admin_client): + """api./api/status must return JSON, not the React WebUI HTML.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'api.{dom}', '/api/status', dns_ip) + assert code not in (0, 000), f"curl to api.{dom}/api/status failed (code {code})" + assert _WEBUI_MARKER not in body, ( + f"api.{dom}/api/status returned WebUI HTML — " + "Caddy is not routing api. to the API; " + "check that the http://api. block exists in the Caddyfile " + "and uses the configured domain (not a stale .cell or .dev TLD)" + ) + assert '{' in body or '"' in body, ( + f"api.{dom}/api/status did not return JSON (body: {body[:100]!r})" + ) + + +def test_api_vip_host_header_routes_to_api(connected_peer, admin_client): + """Caddy routes api. by Host header even when accessed via the Caddy VIP.""" + dom = _domain(admin_client) + code, body = _curl_host('172.20.0.2', f'api.{dom}', '/api/status') + assert _WEBUI_MARKER not in body, ( + f"Host: api.{dom} via 172.20.0.2 returned WebUI HTML — " + "Caddy http://api. block is missing or uses wrong TLD" + ) + + +# ── Scenario 36: calendar. routes to Radicale ──────────────────────── + +def test_calendar_vip_does_not_serve_webui(connected_peer, admin_client): + """calendar. (VIP 172.20.0.21) must proxy to Radicale, not the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'calendar.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to calendar.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"calendar.{dom} returned WebUI HTML — " + "Caddy is not routing calendar. to Radicale. " + "This happens when Caddy has old (e.g. .cell) domain blocks and all " + "traffic falls through to the catch-all :80 block." + ) + + +def test_calendar_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.21 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.21', 'calendar.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.21 (calendar VIP) returned WebUI HTML — " + "Caddy http://calendar., http://172.20.0.21:80 block is missing or stale" + ) + + +# ── Scenario 37: files. routes to Filegator ────────────────────────── + +def test_files_vip_does_not_serve_webui(connected_peer, admin_client): + """files. (VIP 172.20.0.22) must proxy to Filegator, not the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'files.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to files.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"files.{dom} returned WebUI HTML — " + "Caddy is not routing files. to Filegator. " + "Check the http://files., http://172.20.0.22:80 Caddyfile block." + ) + + +def test_files_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.22 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.22', 'files.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.22 (files VIP) returned WebUI HTML — " + "Caddy http://files., http://172.20.0.22:80 block is missing or stale" + ) + + +# ── Scenario 38: mail. routes to Rainloop ──────────────────────────── + +def test_mail_vip_does_not_serve_webui(connected_peer, admin_client): + """mail. (VIP 172.20.0.23) must proxy to Rainloop, not the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'mail.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to mail.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"mail.{dom} returned WebUI HTML — " + "Caddy is not routing mail. to Rainloop." + ) + + +def test_webmail_vip_does_not_serve_webui(connected_peer, admin_client): + """webmail. (alias, same VIP 172.20.0.23) must NOT return the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'webmail.{dom}', '/', dns_ip) + assert _WEBUI_MARKER not in body, ( + f"webmail.{dom} returned WebUI HTML — " + "Caddy http://webmail. block is missing or stale" + ) + + +def test_mail_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.23 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.23', 'mail.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.23 (mail VIP) returned WebUI HTML — " + "Caddy http://mail., http://172.20.0.23:80 block is missing or stale" + ) + + +# ── Scenario 39: webdav. routes to WebDAV ──────────────────────────── + +def test_webdav_vip_does_not_serve_webui(connected_peer, admin_client): + """webdav. (VIP 172.20.0.24) must proxy to the WebDAV service.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'webdav.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to webdav.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"webdav.{dom} returned WebUI HTML — " + "Caddy is not routing webdav. to the WebDAV service." + ) + + +def test_webdav_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.24 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.24', 'webdav.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.24 (webdav VIP) returned WebUI HTML — " + "Caddy http://webdav., http://172.20.0.24:80 block is missing or stale" + ) + + +# ── Scenario 40: VIP IPs without Host header ───────────────────────────────── + +@pytest.mark.parametrize('vip,expected_not', [ + ('172.20.0.21', _WEBUI_MARKER), + ('172.20.0.22', _WEBUI_MARKER), + ('172.20.0.23', _WEBUI_MARKER), + ('172.20.0.24', _WEBUI_MARKER), +]) +def test_vip_direct_access_not_webui(connected_peer, vip, expected_not): + """Each service VIP accessed directly (no special Host) must not return WebUI.""" + code, body = _curl_host(vip, vip) + assert expected_not not in body, ( + f"VIP {vip} returned WebUI HTML — " + "Caddy catch-all :80 is taking over; the per-VIP blocks must listen on port 80" + ) + + +# ── Scenario 41: Catch-all :80 routes API path correctly ───────────────────── + +def test_catchall_api_path_returns_json(connected_peer): + """The catch-all :80 block must route /api/* to the API (not WebUI).""" + code, body = _curl_host('172.20.0.2', 'localhost', '/api/status') + assert _WEBUI_MARKER not in body, ( + "Catch-all :80 returned WebUI HTML for /api/status — " + "the `handle /api/*` directive in the :80 block is missing or wrong" + ) + assert '{' in body or '"' in body, ( + f"/api/status via catch-all did not return JSON (body: {body[:100]!r})" + ) + + +def test_catchall_root_serves_webui(connected_peer): + """The catch-all :80 block serves the WebUI for the root path.""" + code, body = _curl_host('172.20.0.2', 'localhost', '/') + assert _WEBUI_MARKER in body, ( + "Catch-all :80 / did not return WebUI HTML — " + "something is broken with the catch-all :80 block" + ) + + +# ── Scenario extra: stale TLD detection ────────────────────────────────────── + +def test_caddy_does_not_route_cell_tld(connected_peer): + """Caddy must NOT have active routing for .cell domains — they are from old config.""" + code, body = _curl_host('172.20.0.2', 'calendar.cell', '/') + assert _WEBUI_MARKER in body or code in (0, 404, 502, 503), ( + "Caddy is still routing calendar.cell — stale .cell blocks remain in config. " + "Check that write_caddyfile() is writing to the correct path that Caddy reads." + ) diff --git a/tests/e2e/wg/test_wg_domain_access.py b/tests/e2e/wg/test_wg_domain_access.py new file mode 100644 index 0000000..2a6bcb0 --- /dev/null +++ b/tests/e2e/wg/test_wg_domain_access.py @@ -0,0 +1,232 @@ +""" +WireGuard E2E: domain name resolution and HTTP access through the VPN tunnel. + +Scenarios covered: + 30. All service subdomains resolve to the expected IPs via the CoreDNS server + 31. Direct HTTP access to each service IP works through the VPN + 32. HTTP access via domain names works through the VPN (DNS + routing) + 33. WireGuard config downloaded via /api/peer/services has correct DNS field + 34. Peer config DNS points to CoreDNS, not the WireGuard VPN gateway + +Domain name is read from the live API config — these tests do NOT hardcode .dev or .lan. + +These tests require a live PIC stack with WireGuard and are marked `wg`. +They run via `make test-e2e-wg` or `pytest tests/e2e/wg/ -m wg`. +""" + +import subprocess +import pytest + +pytestmark = pytest.mark.wg + +# Subdomain → expected offset in ip_utils.CONTAINER_OFFSETS / VIP list. +# These are the sub-names, not full FQDNs — the TLD is fetched from config. +SUBDOMAINS_TO_IPS = { + 'api': '172.20.0.2', # must route through Caddy (not API container direct) + 'webui': '172.20.0.2', # must route through Caddy + 'calendar': '172.20.0.21', # Caddy VIP for CalDAV + 'files': '172.20.0.22', # Caddy VIP for Filegator + 'mail': '172.20.0.23', # Caddy VIP for Rainloop + 'webmail': '172.20.0.23', # alias for mail VIP + 'webdav': '172.20.0.24', # Caddy VIP for WebDAV +} + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _config(admin_client) -> dict: + r = admin_client.get('/api/config') + return r.json() if r.status_code == 200 else {} + + +def _dns_ip(admin_client) -> str: + cfg = _config(admin_client) + return cfg.get('service_ips', {}).get('dns') or '172.20.0.3' + + +def _domain(admin_client) -> str: + """Return the configured cell domain (e.g. 'lan', 'dev', 'home').""" + return _config(admin_client).get('domain') or 'lan' + + +def _cell_name(admin_client) -> str: + return _config(admin_client).get('cell_name') or 'pic0' + + +# ── Scenario 30: DNS resolution ─────────────────────────────────────────────── + +@pytest.mark.parametrize('subdomain,expected_ip', list(SUBDOMAINS_TO_IPS.items())) +def test_service_domain_resolves_to_expected_ip(connected_peer, admin_client, subdomain, expected_ip): + """Each service subdomain resolves to the correct IP via CoreDNS. + + The full FQDN is built from the configured domain — not hardcoded to any TLD. + """ + dns_ip = _dns_ip(admin_client) + dom = _domain(admin_client) + fqdn = f'{subdomain}.{dom}' + result = subprocess.run( + ['dig', f'@{dns_ip}', fqdn, 'A', '+short', '+time=5'], + capture_output=True, text=True, timeout=10, + ) + assert result.returncode == 0, f"dig failed for {fqdn}: {result.stderr}" + resolved = result.stdout.strip() + assert resolved == expected_ip, ( + f"{fqdn} resolved to {resolved!r}, expected {expected_ip}. " + f"DNS server: {dns_ip}, configured domain: {dom!r}" + ) + + +def test_cell_hostname_resolves_to_caddy(connected_peer, admin_client): + """The cell hostname (e.g. pic0.lan) resolves to Caddy.""" + dns_ip = _dns_ip(admin_client) + dom = _domain(admin_client) + name = _cell_name(admin_client) + fqdn = f'{name}.{dom}' + result = subprocess.run( + ['dig', f'@{dns_ip}', fqdn, 'A', '+short', '+time=5'], + capture_output=True, text=True, timeout=10, + ) + resolved = result.stdout.strip() + assert resolved == '172.20.0.2', ( + f"{fqdn} should resolve to Caddy (172.20.0.2); got {resolved!r}" + ) + + +def test_api_domain_does_not_resolve_to_api_container(connected_peer, admin_client): + """api. must route through Caddy — API container listens on :3000, not :80.""" + dns_ip = _dns_ip(admin_client) + dom = _domain(admin_client) + result = subprocess.run( + ['dig', f'@{dns_ip}', f'api.{dom}', 'A', '+short', '+time=5'], + capture_output=True, text=True, timeout=10, + ) + resolved = result.stdout.strip() + assert resolved != '172.20.0.10', ( + f"api.{dom} resolves to 172.20.0.10 (API container direct) — " + "this bypasses Caddy so port-80 requests return nothing; must be Caddy 172.20.0.2" + ) + assert resolved == '172.20.0.2', f"api.{dom} should be Caddy 172.20.0.2; got {resolved}" + + +def test_webui_domain_does_not_resolve_to_webui_container(connected_peer, admin_client): + """webui. must route through Caddy.""" + dns_ip = _dns_ip(admin_client) + dom = _domain(admin_client) + result = subprocess.run( + ['dig', f'@{dns_ip}', f'webui.{dom}', 'A', '+short', '+time=5'], + capture_output=True, text=True, timeout=10, + ) + resolved = result.stdout.strip() + assert resolved == '172.20.0.2', f"webui.{dom} should be Caddy 172.20.0.2; got {resolved}" + + +# ── Scenario 31: HTTP via IP ─────────────────────────────────────────────────── + +def test_caddy_ip_serves_http(connected_peer): + """Caddy at 172.20.0.2 returns an HTTP response through the VPN.""" + result = subprocess.run( + ['curl', '-s', '-o', '/dev/null', '-w', '%{http_code}', '--connect-timeout', '5', + 'http://172.20.0.2/'], + capture_output=True, text=True, timeout=10, + ) + code = result.stdout.strip() + assert code not in ('000', ''), f"No HTTP response from 172.20.0.2; curl exit {result.returncode}" + + +# ── Scenario 32: HTTP via domain ────────────────────────────────────────────── + +def test_http_api_domain_reaches_api(connected_peer, admin_client): + """curl http://api./api/status returns a JSON response via Caddy + CoreDNS.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + result = subprocess.run( + ['curl', '-s', '--connect-timeout', '5', + '--dns-servers', dns_ip, + f'http://api.{dom}/api/status'], + capture_output=True, text=True, timeout=10, + ) + assert result.stdout.strip(), ( + f"curl http://api.{dom}/api/status returned no output via DNS {dns_ip}. " + f"stderr: {result.stderr[:200]}" + ) + + +# ── Scenario 33: Config DNS field ───────────────────────────────────────────── + +def test_peer_services_config_has_coredns_not_vpn_gateway(admin_client, make_peer): + """WireGuard config in /api/peer/services must use CoreDNS IP, not 10.0.0.1.""" + from helpers.api_client import PicAPIClient + import os + + peer = make_peer('e2etest-dns-config', password='DnsTest123!') + peer_client = PicAPIClient(os.environ.get('PIC_API_BASE', 'http://192.168.31.51:3000')) + peer_client.login(peer['name'], 'DnsTest123!') + + r = peer_client.get('/api/peer/services') + assert r.status_code == 200, f"peer services returned {r.status_code}: {r.text}" + data = r.json() + + dns = data.get('wireguard', {}).get('dns', '') + assert dns != '10.0.0.1', ( + "wireguard.dns is 10.0.0.1 — this is the WireGuard VPN gateway, not a DNS server; " + "VPN clients using this as DNS will fail to resolve all domain names" + ) + + config = data.get('wireguard', {}).get('config', '') + if config: + assert 'DNS = 10.0.0.1' not in config, ( + "WireGuard client config has DNS = 10.0.0.1 — " + "VPN clients will fail to resolve domain names" + ) + for line in config.splitlines(): + if line.strip().startswith('DNS ='): + dns_from_config = line.split('=', 1)[1].strip() + assert dns_from_config.startswith('172.'), ( + f"DNS in config is {dns_from_config} — expected a 172.x.x.x Docker IP; " + "CoreDNS lives on the Docker bridge, not the WireGuard VPN subnet" + ) + break + + +def test_peer_services_caldav_url_uses_configured_domain(admin_client, make_peer): + """CalDAV URL must use the configured domain, not hardcode 'radicale.dev:5232'.""" + from helpers.api_client import PicAPIClient + import os + + dom = _domain(admin_client) + peer = make_peer('e2etest-caldav-url', password='CaldavTest123!') + peer_client = PicAPIClient(os.environ.get('PIC_API_BASE', 'http://192.168.31.51:3000')) + peer_client.login(peer['name'], 'CaldavTest123!') + + r = peer_client.get('/api/peer/services') + assert r.status_code == 200 + url = r.json().get('caldav', {}).get('url', '') + + assert f'calendar.{dom}' in url, ( + f"CalDAV URL {url!r} does not contain 'calendar.{dom}' — " + f"must use configured domain '{dom}', not a hardcoded TLD" + ) + assert 'radicale' not in url, ( + f"CalDAV URL {url!r} contains 'radicale' — no radicale. DNS record exists" + ) + assert ':5232' not in url, ( + f"CalDAV URL {url!r} exposes internal port 5232 — use Caddy-proxied URL" + ) + + +# ── Scenario 34: DNS reachability from VPN ──────────────────────────────────── + +def test_coredns_reachable_via_vpn(connected_peer, admin_client): + """CoreDNS is reachable through the WireGuard VPN tunnel.""" + dns_ip = _dns_ip(admin_client) + result = subprocess.run( + ['dig', f'@{dns_ip}', 'health.check', '+time=3', '+tries=1'], + capture_output=True, text=True, timeout=8, + ) + # NXDOMAIN means DNS responded — connectivity is what we test here + responded = 'status:' in result.stdout or result.returncode in (0, 9) + assert responded, ( + f"CoreDNS at {dns_ip} did not respond via VPN tunnel. " + f"Check that peer AllowedIPs covers the Docker network or 0.0.0.0/0. " + f"stdout: {result.stdout[:200]}" + ) diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py index 9ea6c9c..6e7ee71 100644 --- a/tests/test_api_endpoints.py +++ b/tests/test_api_endpoints.py @@ -366,8 +366,8 @@ class TestAPIEndpoints(unittest.TestCase): def test_email_endpoints(self, mock_email): # Ensure all relevant mock methods return JSON-serializable values mock_email.get_users.return_value = [{'username': 'user1', 'domain': 'cell', 'email': 'user1@cell'}] - mock_email.create_user.return_value = True - mock_email.delete_user.return_value = True + mock_email.create_email_user.return_value = True + mock_email.delete_email_user.return_value = True mock_email.get_status.return_value = {'postfix_running': True, 'dovecot_running': True, 'total_users': 1, 'total_size_bytes': 0, 'total_size_mb': 0.0, 'users': [{'username': 'user1', 'domain': 'cell', 'email': 'user1@cell'}]} mock_email.test_connectivity.return_value = {'smtp': {'success': True, 'message': 'SMTP server responding'}} mock_email.send_email.return_value = True @@ -383,17 +383,17 @@ class TestAPIEndpoints(unittest.TestCase): # /api/email/users (POST) response = self.client.post('/api/email/users', data=json.dumps({'username': 'user1', 'domain': 'cell', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 200) - mock_email.create_user.side_effect = Exception('fail') + mock_email.create_email_user.side_effect = Exception('fail') response = self.client.post('/api/email/users', data=json.dumps({'username': 'user1', 'domain': 'cell', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 500) - mock_email.create_user.side_effect = None + mock_email.create_email_user.side_effect = None # /api/email/users/ (DELETE) response = self.client.delete('/api/email/users/user1') self.assertEqual(response.status_code, 200) - mock_email.delete_user.side_effect = Exception('fail') + mock_email.delete_email_user.side_effect = Exception('fail') response = self.client.delete('/api/email/users/user1') self.assertEqual(response.status_code, 500) - mock_email.delete_user.side_effect = None + mock_email.delete_email_user.side_effect = None # /api/email/status (GET) response = self.client.get('/api/email/status') self.assertEqual(response.status_code, 200) @@ -427,8 +427,8 @@ class TestAPIEndpoints(unittest.TestCase): def test_calendar_endpoints(self, mock_calendar): # Mock return values for all relevant calendar_manager methods mock_calendar.get_users.return_value = [{'username': 'user1', 'collections': {'calendars': ['cal1'], 'contacts': ['c1']}}] - mock_calendar.create_user.return_value = True - mock_calendar.delete_user.return_value = True + mock_calendar.create_calendar_user.return_value = True + mock_calendar.delete_calendar_user.return_value = True mock_calendar.create_calendar.return_value = {'calendar': 'cal1'} mock_calendar.add_event.return_value = {'event': 'event1'} mock_calendar.get_events.return_value = [{'event': 'event1'}] @@ -445,17 +445,17 @@ class TestAPIEndpoints(unittest.TestCase): # /api/calendar/users (POST) response = self.client.post('/api/calendar/users', data=json.dumps({'username': 'user1', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 200) - mock_calendar.create_user.side_effect = Exception('fail') + mock_calendar.create_calendar_user.side_effect = Exception('fail') response = self.client.post('/api/calendar/users', data=json.dumps({'username': 'user1', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 500) - mock_calendar.create_user.side_effect = None + mock_calendar.create_calendar_user.side_effect = None # /api/calendar/users/ (DELETE) response = self.client.delete('/api/calendar/users/user1') self.assertEqual(response.status_code, 200) - mock_calendar.delete_user.side_effect = Exception('fail') + mock_calendar.delete_calendar_user.side_effect = Exception('fail') response = self.client.delete('/api/calendar/users/user1') self.assertEqual(response.status_code, 500) - mock_calendar.delete_user.side_effect = None + mock_calendar.delete_calendar_user.side_effect = None # /api/calendar/calendars (POST) response = self.client.post('/api/calendar/calendars', data=json.dumps({'username': 'user1', 'calendar_name': 'cal1'}), content_type='application/json') self.assertEqual(response.status_code, 200) @@ -599,10 +599,10 @@ class TestAPIEndpoints(unittest.TestCase): self.assertEqual(response.status_code, 500) mock_routing.get_firewall_rules.side_effect = None # /api/routing/peers (POST) - response = self.client.post('/api/routing/peers', data=json.dumps({'peer': 'peer1', 'route': '10.0.0.2'}), content_type='application/json') + response = self.client.post('/api/routing/peers', data=json.dumps({'peer_name': 'peer1', 'peer_ip': '10.0.0.2'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_peer_route.side_effect = Exception('fail') - response = self.client.post('/api/routing/peers', data=json.dumps({'peer': 'peer1', 'route': '10.0.0.2'}), content_type='application/json') + response = self.client.post('/api/routing/peers', data=json.dumps({'peer_name': 'peer1', 'peer_ip': '10.0.0.2'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_peer_route.side_effect = None # /api/routing/peers (GET) @@ -620,24 +620,24 @@ class TestAPIEndpoints(unittest.TestCase): self.assertEqual(response.status_code, 500) mock_routing.remove_peer_route.side_effect = None # /api/routing/exit-nodes (POST) - response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'node': 'exit1'}), content_type='application/json') + response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'peer_name': 'exit1', 'peer_ip': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_exit_node.side_effect = Exception('fail') - response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'node': 'exit1'}), content_type='application/json') + response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'peer_name': 'exit1', 'peer_ip': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_exit_node.side_effect = None # /api/routing/bridge (POST) - response = self.client.post('/api/routing/bridge', data=json.dumps({'bridge': 'br1'}), content_type='application/json') + response = self.client.post('/api/routing/bridge', data=json.dumps({'source_peer': 'peer1', 'target_peer': 'peer2'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_bridge_route.side_effect = Exception('fail') - response = self.client.post('/api/routing/bridge', data=json.dumps({'bridge': 'br1'}), content_type='application/json') + response = self.client.post('/api/routing/bridge', data=json.dumps({'source_peer': 'peer1', 'target_peer': 'peer2'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_bridge_route.side_effect = None # /api/routing/split (POST) - response = self.client.post('/api/routing/split', data=json.dumps({'split': 'sp1'}), content_type='application/json') + response = self.client.post('/api/routing/split', data=json.dumps({'network': '10.0.0.0/24', 'exit_peer': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_split_route.side_effect = Exception('fail') - response = self.client.post('/api/routing/split', data=json.dumps({'split': 'sp1'}), content_type='application/json') + response = self.client.post('/api/routing/split', data=json.dumps({'network': '10.0.0.0/24', 'exit_peer': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_split_route.side_effect = None # /api/routing/connectivity (POST) diff --git a/tests/test_app_misc.py b/tests/test_app_misc.py index e326921..8b6e8c5 100644 --- a/tests/test_app_misc.py +++ b/tests/test_app_misc.py @@ -113,8 +113,11 @@ class TestAppMisc(unittest.TestCase): self.assertFalse(app_module.is_local_request()) def test_is_local_request_private_ip(self): + # 192.168.x.x (LAN) is no longer trusted — only Docker bridge (172.16.0.0/12) + # and loopback are trusted. The API is bound to 127.0.0.1:3000 and only + # reachable via Caddy (172.20.x.x), so LAN IPs never reach it directly. with patch('app.request', new=self._req('192.168.1.5')): - self.assertTrue(app_module.is_local_request()) + self.assertFalse(app_module.is_local_request()) def test_is_local_request_xff_spoof_rejected(self): # Client sends X-Forwarded-For: 127.0.0.1 but actual IP is public @@ -123,8 +126,14 @@ class TestAppMisc(unittest.TestCase): self.assertFalse(app_module.is_local_request()) def test_is_local_request_xff_last_entry_local(self): - # Caddy appends the real client IP; last entry is local → allow + # 192.168.x.x is no longer in the trusted range — only Docker bridge + # (172.16.0.0/12) and loopback are trusted now. with patch('app.request', new=self._req('8.8.8.8', xff='8.8.8.8, 192.168.1.10')): + self.assertFalse(app_module.is_local_request()) + + def test_is_local_request_xff_docker_bridge(self): + # Docker bridge IPs (172.16.0.0/12) ARE trusted — Caddy uses this range + with patch('app.request', new=self._req('8.8.8.8', xff='8.8.8.8, 172.20.0.2')): self.assertTrue(app_module.is_local_request()) def test_is_local_request_xff_single_public_rejected(self): diff --git a/tests/test_calendar_endpoints.py b/tests/test_calendar_endpoints.py index 1f55537..b5ea38c 100644 --- a/tests/test_calendar_endpoints.py +++ b/tests/test_calendar_endpoints.py @@ -1 +1,379 @@ -# ... moved and adapted code from test_phase3_endpoints.py (calendar section) ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for calendar Flask endpoints in api/app.py. + +Covers: + GET /api/calendar/users + POST /api/calendar/users + DELETE /api/calendar/users/ + POST /api/calendar/calendars + POST /api/calendar/events + GET /api/calendar/events// + GET /api/calendar/status + GET /api/calendar/connectivity +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetCalendarUsers(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_get_users_returns_200_with_list(self, mock_cm): + mock_cm.get_users.return_value = [ + {'username': 'alice', 'email': 'alice@cell'}, + {'username': 'bob', 'email': 'bob@cell'}, + ] + r = self.client.get('/api/calendar/users') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.calendar_manager') + def test_get_users_returns_200_with_empty_list(self, mock_cm): + mock_cm.get_users.return_value = [] + r = self.client.get('/api/calendar/users') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.calendar_manager') + def test_get_users_returns_500_on_exception(self, mock_cm): + mock_cm.get_users.side_effect = Exception('radicale unreachable') + r = self.client.get('/api/calendar/users') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCreateCalendarUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_create_user_returns_200_on_valid_body(self, mock_cm): + mock_cm.create_calendar_user.return_value = True + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.calendar_manager') + def test_create_user_passes_credentials_to_manager(self, mock_cm): + mock_cm.create_calendar_user.return_value = True + self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + mock_cm.create_calendar_user.assert_called_once_with('alice', 'secret123') + + @patch('app.calendar_manager') + def test_create_user_returns_400_when_no_body(self, mock_cm): + r = self.client.post('/api/calendar/users') + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + mock_cm.create_calendar_user.assert_not_called() + + @patch('app.calendar_manager') + def test_create_user_returns_400_when_username_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.create_calendar_user.assert_not_called() + + @patch('app.calendar_manager') + def test_create_user_returns_400_when_password_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.create_calendar_user.assert_not_called() + + @patch('app.calendar_manager') + def test_create_user_returns_500_on_exception(self, mock_cm): + mock_cm.create_calendar_user.side_effect = Exception('htpasswd write failure') + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteCalendarUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_delete_user_returns_200_on_success(self, mock_cm): + mock_cm.delete_calendar_user.return_value = True + r = self.client.delete('/api/calendar/users/alice') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('deleted', data) + + @patch('app.calendar_manager') + def test_delete_user_passes_username_to_manager(self, mock_cm): + mock_cm.delete_calendar_user.return_value = True + self.client.delete('/api/calendar/users/bob') + mock_cm.delete_calendar_user.assert_called_once_with('bob') + + @patch('app.calendar_manager') + def test_delete_user_returns_500_on_exception(self, mock_cm): + mock_cm.delete_calendar_user.side_effect = Exception('user not found') + r = self.client.delete('/api/calendar/users/alice') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCreateCalendar(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_create_calendar_returns_200_on_valid_body(self, mock_cm): + mock_cm.create_calendar.return_value = True + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice', 'name': 'Work'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.calendar_manager') + def test_create_calendar_accepts_calendar_name_alias(self, mock_cm): + mock_cm.create_calendar.return_value = True + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice', 'calendar_name': 'Personal'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + + @patch('app.calendar_manager') + def test_create_calendar_returns_400_when_no_body(self, mock_cm): + r = self.client.post('/api/calendar/calendars') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.create_calendar.assert_not_called() + + @patch('app.calendar_manager') + def test_create_calendar_returns_400_when_username_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'name': 'Work'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_create_calendar_returns_400_when_name_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_create_calendar_returns_500_on_exception(self, mock_cm): + mock_cm.create_calendar.side_effect = Exception('CalDAV error') + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice', 'name': 'Work'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddCalendarEvent(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_add_event_returns_200_on_valid_body(self, mock_cm): + mock_cm.add_event.return_value = 'event-uid-123' + r = self.client.post( + '/api/calendar/events', + data=json.dumps({ + 'username': 'alice', + 'calendar_name': 'Work', + 'summary': 'Team Meeting', + 'dtstart': '20260427T100000Z', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.calendar_manager') + def test_add_event_returns_400_when_no_body(self, mock_cm): + r = self.client.post('/api/calendar/events') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.add_event.assert_not_called() + + @patch('app.calendar_manager') + def test_add_event_returns_400_when_username_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/events', + data=json.dumps({'calendar_name': 'Work', 'summary': 'Meeting'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_add_event_returns_400_when_calendar_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/events', + data=json.dumps({'username': 'alice', 'summary': 'Meeting'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_add_event_returns_500_on_exception(self, mock_cm): + mock_cm.add_event.side_effect = Exception('iCalendar parse error') + r = self.client.post( + '/api/calendar/events', + data=json.dumps({ + 'username': 'alice', + 'calendar_name': 'Work', + 'summary': 'Meeting', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetCalendarEvents(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_get_events_returns_200_with_events(self, mock_cm): + mock_cm.get_events.return_value = [ + {'uid': 'abc', 'summary': 'Standup', 'dtstart': '20260427T090000Z'}, + ] + r = self.client.get('/api/calendar/events/alice/Work') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 1) + + @patch('app.calendar_manager') + def test_get_events_passes_username_and_calendar_to_manager(self, mock_cm): + mock_cm.get_events.return_value = [] + self.client.get('/api/calendar/events/bob/Personal') + mock_cm.get_events.assert_called_once() + args = mock_cm.get_events.call_args[0] + self.assertEqual(args[0], 'bob') + self.assertEqual(args[1], 'Personal') + + @patch('app.calendar_manager') + def test_get_events_returns_500_on_exception(self, mock_cm): + mock_cm.get_events.side_effect = Exception('calendar not found') + r = self.client.get('/api/calendar/events/alice/Work') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetCalendarStatus(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_get_status_returns_200_with_status_dict(self, mock_cm): + mock_cm.get_status.return_value = { + 'running': True, + 'port': 5232, + 'users_count': 3, + } + r = self.client.get('/api/calendar/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('running', data) + + @patch('app.calendar_manager') + def test_get_status_returns_500_on_exception(self, mock_cm): + mock_cm.get_status.side_effect = Exception('container not found') + r = self.client.get('/api/calendar/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCalendarConnectivity(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_connectivity_returns_200_with_result(self, mock_cm): + mock_cm.test_connectivity.return_value = { + 'caldav': True, + 'carddav': True, + 'latency_ms': 8, + } + r = self.client.get('/api/calendar/connectivity') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('caldav', data) + + @patch('app.calendar_manager') + def test_connectivity_returns_500_on_exception(self, mock_cm): + mock_cm.test_connectivity.side_effect = Exception('connection refused') + r = self.client.get('/api/calendar/connectivity') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_cell_link_dns.py b/tests/test_cell_link_dns.py new file mode 100644 index 0000000..297d5d8 --- /dev/null +++ b/tests/test_cell_link_dns.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +""" +Tests for cell-to-cell DNS forwarding integration. + +Covers: + - generate_corefile() with cell_links entries + - apply_all_dns_rules() passing cell_links through to generate_corefile() + - Correct domain/dns_ip values in the emitted forwarding stanza + - Validation: invalid characters in domain are rejected by add_cell_dns_forward() +""" + +import sys +import os +import tempfile +import shutil +import unittest +from unittest.mock import patch, MagicMock, call +from pathlib import Path + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +import firewall_manager + + +# --------------------------------------------------------------------------- +# generate_corefile() with cell_links +# --------------------------------------------------------------------------- + +class TestGenerateCorefileOneLink(unittest.TestCase): + """generate_corefile() with a single cell link produces the right stanza.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.path = os.path.join(self.tmp, 'Corefile') + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _read(self): + return open(self.path).read() + + def test_forwarding_block_present(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('remote.cell {', content) + + def test_correct_dns_ip_in_forward_directive(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('forward . 10.5.0.1', content) + + def test_cache_directive_present_in_forwarding_block(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + # 'cache' must appear in the forwarding block (after the primary zone block) + idx_primary = content.index('remote.cell {') + self.assertIn('cache', content[idx_primary:]) + + def test_log_directive_present_in_forwarding_block(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + idx_primary = content.index('remote.cell {') + self.assertIn('log', content[idx_primary:]) + + def test_forwarding_block_appears_after_primary_zone(self): + """The cell link stanza must appear after the primary zone block, not inside it.""" + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + # Primary zone ends with its closing brace; remote.cell block follows + idx_primary_zone = content.index('cell {') + idx_forward_block = content.index('remote.cell {') + self.assertGreater(idx_forward_block, idx_primary_zone) + + +class TestGenerateCorefileMultipleLinks(unittest.TestCase): + """generate_corefile() with multiple cell links produces one stanza each.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.path = os.path.join(self.tmp, 'Corefile') + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _read(self): + return open(self.path).read() + + def test_all_domains_present(self): + cell_links = [ + {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'}, + {'domain': 'gamma.cell', 'dns_ip': '10.3.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('alpha.cell {', content) + self.assertIn('beta.cell {', content) + self.assertIn('gamma.cell {', content) + + def test_all_dns_ips_present(self): + cell_links = [ + {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('forward . 10.1.0.1', content) + self.assertIn('forward . 10.2.0.1', content) + + def test_stanza_count_matches_link_count(self): + """Each valid link contributes exactly one forwarding stanza.""" + cell_links = [ + {'domain': 'a.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'b.cell', 'dns_ip': '10.2.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + # Count occurrences of 'forward .' — one for default, one per cell link + count = content.count('forward .') + self.assertEqual(count, 3) # 1 default + 2 cell links + + +# --------------------------------------------------------------------------- +# apply_all_dns_rules() passes cell_links through to generate_corefile() +# --------------------------------------------------------------------------- + +class TestApplyAllDnsRulesPassesCellLinks(unittest.TestCase): + """apply_all_dns_rules() must forward the cell_links argument to generate_corefile().""" + + def test_cell_links_forwarded(self): + cell_links = [{'domain': 'x.cell', 'dns_ip': '10.9.0.1'}] + with patch.object(firewall_manager, 'generate_corefile', return_value=True) as mock_gen, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True): + firewall_manager.apply_all_dns_rules( + peers=[], + corefile_path='/tmp/fake_Corefile', + domain='cell', + cell_links=cell_links, + ) + mock_gen.assert_called_once_with( + [], '/tmp/fake_Corefile', 'cell', cell_links + ) + + def test_cell_links_none_forwarded_as_none(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=True) as mock_gen, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True): + firewall_manager.apply_all_dns_rules( + peers=[], + corefile_path='/tmp/fake_Corefile', + domain='cell', + cell_links=None, + ) + mock_gen.assert_called_once_with([], '/tmp/fake_Corefile', 'cell', None) + + def test_reload_called_on_success(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=True), \ + patch.object(firewall_manager, 'reload_coredns', return_value=True) as mock_reload: + firewall_manager.apply_all_dns_rules([], '/tmp/f', cell_links=None) + mock_reload.assert_called_once() + + def test_reload_not_called_on_failure(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=False), \ + patch.object(firewall_manager, 'reload_coredns') as mock_reload: + firewall_manager.apply_all_dns_rules([], '/tmp/f', cell_links=None) + mock_reload.assert_not_called() + + +# --------------------------------------------------------------------------- +# Domain validation in add_cell_dns_forward() (via network_manager) +# --------------------------------------------------------------------------- + +class TestAddCellDnsForwardValidation(unittest.TestCase): + """ + add_cell_dns_forward() must reject malformed domains/IPs without writing + the Corefile or calling apply_all_dns_rules(). + """ + + def _get_network_manager(self, tmp_dir): + """Construct a minimal NetworkManager with test directories.""" + # We import here so the test file doesn't hard-fail if network_manager + # has an import-time dependency that's unavailable in CI. + try: + from network_manager import NetworkManager + except ImportError as e: + self.skipTest(f'NetworkManager import failed: {e}') + os.makedirs(os.path.join(tmp_dir, 'dns'), exist_ok=True) + return NetworkManager(data_dir=tmp_dir, config_dir=tmp_dir) + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_invalid_dns_ip_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('valid.cell', 'not-an-ip') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_domain_with_newline_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('evil\ndomain', '10.1.0.1') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_domain_with_braces_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('evil{domain}', '10.1.0.1') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_domain_with_space_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('evil domain', '10.1.0.1') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_valid_domain_and_ip_calls_apply_all_dns_rules(self): + """Valid inputs must call firewall_manager.apply_all_dns_rules().""" + nm = self._get_network_manager(self.tmp) + with patch.object(firewall_manager, 'apply_all_dns_rules', return_value=True) as mock_apply, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True): + result = nm.add_cell_dns_forward('valid.cell', '10.1.0.1') + mock_apply.assert_called_once() + call_kwargs = mock_apply.call_args + # cell_links kwarg must include the new entry + cell_links_arg = call_kwargs[1].get('cell_links') or call_kwargs[0][3] + domains = [l['domain'] for l in cell_links_arg] + self.assertIn('valid.cell', domains) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_cells_endpoints.py b/tests/test_cells_endpoints.py new file mode 100644 index 0000000..6a2c351 --- /dev/null +++ b/tests/test_cells_endpoints.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +""" +Unit tests for cell management Flask endpoints in api/app.py. + +Covers: + GET /api/cells/invite — generate invite package + GET /api/cells — list connected cells + POST /api/cells — connect to a remote cell + DELETE /api/cells/ — disconnect from a cell + GET /api/cells//status — live status for a connected cell +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + +# Minimal set of required fields for POST /api/cells +_VALID_CELL_BODY = { + 'cell_name': 'remotecell', + 'public_key': 'abc123publickey==', + 'vpn_subnet': '10.1.0.0/24', + 'dns_ip': '10.1.0.1', + 'domain': 'remotecell.cell', +} + + +class TestGetCellInvite(unittest.TestCase): + """GET /api/cells/invite""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + @patch('app.config_manager') + def test_get_invite_returns_200_with_invite_dict(self, mock_cfg, mock_clm): + mock_cfg.configs = {'_identity': {'cell_name': 'mycell', 'domain': 'cell'}} + mock_clm.generate_invite.return_value = { + 'cell_name': 'mycell', + 'public_key': 'server_pub_key==', + 'vpn_subnet': '10.0.0.0/24', + 'dns_ip': '10.0.0.1', + 'domain': 'cell', + } + r = self.client.get('/api/cells/invite') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('cell_name', data) + self.assertIn('public_key', data) + + @patch('app.cell_link_manager') + @patch('app.config_manager') + def test_get_invite_passes_cell_name_and_domain(self, mock_cfg, mock_clm): + mock_cfg.configs = {'_identity': {'cell_name': 'myhome', 'domain': 'home'}} + mock_clm.generate_invite.return_value = {} + self.client.get('/api/cells/invite') + mock_clm.generate_invite.assert_called_once_with('myhome', 'home') + + @patch('app.cell_link_manager') + @patch('app.config_manager') + def test_get_invite_returns_500_on_exception(self, mock_cfg, mock_clm): + mock_cfg.configs = {'_identity': {}} + mock_clm.generate_invite.side_effect = Exception('WireGuard key unavailable') + r = self.client.get('/api/cells/invite') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestListCellConnections(unittest.TestCase): + """GET /api/cells""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_list_cells_returns_200_with_list(self, mock_clm): + mock_clm.list_connections.return_value = [ + {'cell_name': 'remotecell', 'domain': 'remotecell.cell', 'status': 'connected'}, + ] + r = self.client.get('/api/cells') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 1) + self.assertEqual(data[0]['cell_name'], 'remotecell') + + @patch('app.cell_link_manager') + def test_list_cells_returns_empty_list_when_none_connected(self, mock_clm): + mock_clm.list_connections.return_value = [] + r = self.client.get('/api/cells') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.cell_link_manager') + def test_list_cells_returns_500_on_exception(self, mock_clm): + mock_clm.list_connections.side_effect = Exception('storage error') + r = self.client.get('/api/cells') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddCellConnection(unittest.TestCase): + """POST /api/cells""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_add_cell_returns_201_on_success(self, mock_clm): + mock_clm.add_connection.return_value = {'cell_name': 'remotecell'} + r = self.client.post( + '/api/cells', + data=json.dumps(_VALID_CELL_BODY), + content_type='application/json', + ) + self.assertEqual(r.status_code, 201) + data = json.loads(r.data) + self.assertIn('message', data) + self.assertIn('link', data) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_no_body(self, mock_clm): + r = self.client.post('/api/cells') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_clm.add_connection.assert_not_called() + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_cell_name_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'cell_name'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_public_key_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'public_key'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_vpn_subnet_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'vpn_subnet'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_dns_ip_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'dns_ip'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_domain_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'domain'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_on_value_error_from_manager(self, mock_clm): + mock_clm.add_connection.side_effect = ValueError('cell already connected') + r = self.client.post( + '/api/cells', + data=json.dumps(_VALID_CELL_BODY), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_500_on_unexpected_exception(self, mock_clm): + mock_clm.add_connection.side_effect = Exception('WireGuard peer add failed') + r = self.client.post( + '/api/cells', + data=json.dumps(_VALID_CELL_BODY), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestRemoveCellConnection(unittest.TestCase): + """DELETE /api/cells/""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_remove_cell_returns_200_on_success(self, mock_clm): + mock_clm.remove_connection.return_value = None + r = self.client.delete('/api/cells/remotecell') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.cell_link_manager') + def test_remove_cell_passes_cell_name_to_manager(self, mock_clm): + mock_clm.remove_connection.return_value = None + self.client.delete('/api/cells/faraway') + mock_clm.remove_connection.assert_called_once_with('faraway') + + @patch('app.cell_link_manager') + def test_remove_cell_returns_404_on_value_error(self, mock_clm): + mock_clm.remove_connection.side_effect = ValueError('cell not found') + r = self.client.delete('/api/cells/nonexistent') + self.assertEqual(r.status_code, 404) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_remove_cell_returns_500_on_unexpected_exception(self, mock_clm): + mock_clm.remove_connection.side_effect = Exception('storage corruption') + r = self.client.delete('/api/cells/remotecell') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetCellConnectionStatus(unittest.TestCase): + """GET /api/cells//status""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_get_cell_status_returns_200_with_status_dict(self, mock_clm): + mock_clm.get_connection_status.return_value = { + 'cell_name': 'remotecell', + 'online': True, + 'last_handshake': '2026-04-27T09:00:00Z', + 'transfer_rx': 1024, + 'transfer_tx': 2048, + } + r = self.client.get('/api/cells/remotecell/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('online', data) + self.assertTrue(data['online']) + + @patch('app.cell_link_manager') + def test_get_cell_status_passes_cell_name(self, mock_clm): + mock_clm.get_connection_status.return_value = {} + self.client.get('/api/cells/faraway/status') + mock_clm.get_connection_status.assert_called_once_with('faraway') + + @patch('app.cell_link_manager') + def test_get_cell_status_returns_404_on_value_error(self, mock_clm): + mock_clm.get_connection_status.side_effect = ValueError('cell not found') + r = self.client.get('/api/cells/missing/status') + self.assertEqual(r.status_code, 404) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_get_cell_status_returns_500_on_unexpected_exception(self, mock_clm): + mock_clm.get_connection_status.side_effect = Exception('WireGuard query failed') + r = self.client.get('/api/cells/remotecell/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_email_endpoints.py b/tests/test_email_endpoints.py index 0ce01fe..57cbd9e 100644 --- a/tests/test_email_endpoints.py +++ b/tests/test_email_endpoints.py @@ -1 +1,212 @@ -# ... moved and adapted code from test_phase3_endpoints.py (email section) ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for email Flask endpoints in api/app.py. + +Covers: + GET /api/email/users + POST /api/email/users + DELETE /api/email/users/ + GET /api/email/status + GET /api/email/connectivity +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetEmailUsers(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_get_users_returns_200_with_list(self, mock_em): + mock_em.get_users.return_value = [ + {'username': 'alice@cell', 'domain': 'cell'}, + {'username': 'bob@cell', 'domain': 'cell'}, + ] + r = self.client.get('/api/email/users') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.email_manager') + def test_get_users_returns_empty_list_when_no_users(self, mock_em): + mock_em.get_users.return_value = [] + r = self.client.get('/api/email/users') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.email_manager') + def test_get_users_returns_500_on_exception(self, mock_em): + mock_em.get_users.side_effect = Exception('mailbox unreachable') + r = self.client.get('/api/email/users') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCreateEmailUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_create_user_returns_200_on_success(self, mock_em): + mock_em.create_email_user.return_value = True + r = self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.email_manager') + def test_create_user_calls_manager_with_username_and_password(self, mock_em): + mock_em.create_email_user.return_value = True + self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + mock_em.create_email_user.assert_called_once() + args = mock_em.create_email_user.call_args[0] + self.assertEqual(args[0], 'alice') + self.assertEqual(args[2], 'secret123') + + @patch('app.email_manager') + def test_create_user_returns_400_when_username_missing(self, mock_em): + r = self.client.post( + '/api/email/users', + data=json.dumps({'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_em.create_email_user.assert_not_called() + + @patch('app.email_manager') + def test_create_user_returns_400_when_password_missing(self, mock_em): + r = self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_em.create_email_user.assert_not_called() + + @patch('app.email_manager') + def test_create_user_returns_400_when_no_body(self, mock_em): + r = self.client.post('/api/email/users') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.email_manager') + def test_create_user_returns_500_on_exception(self, mock_em): + mock_em.create_email_user.side_effect = Exception('SMTP config error') + r = self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteEmailUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_delete_user_returns_200_on_success(self, mock_em): + mock_em.delete_email_user.return_value = True + r = self.client.delete('/api/email/users/alice') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('deleted', data) + + @patch('app.email_manager') + def test_delete_user_calls_manager_with_username(self, mock_em): + mock_em.delete_email_user.return_value = True + self.client.delete('/api/email/users/bob') + mock_em.delete_email_user.assert_called_once() + args = mock_em.delete_email_user.call_args[0] + self.assertEqual(args[0], 'bob') + + @patch('app.email_manager') + def test_delete_user_returns_500_on_exception(self, mock_em): + mock_em.delete_email_user.side_effect = Exception('LDAP error') + r = self.client.delete('/api/email/users/alice') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetEmailStatus(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_get_status_returns_200_with_status_dict(self, mock_em): + mock_em.get_status.return_value = { + 'running': True, + 'smtp_port': 25, + 'imap_port': 993, + } + r = self.client.get('/api/email/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('running', data) + + @patch('app.email_manager') + def test_get_status_returns_500_on_exception(self, mock_em): + mock_em.get_status.side_effect = Exception('container not found') + r = self.client.get('/api/email/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestEmailConnectivity(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_connectivity_returns_200_with_result(self, mock_em): + mock_em.test_connectivity.return_value = { + 'smtp': True, + 'imap': True, + 'latency_ms': 12, + } + r = self.client.get('/api/email/connectivity') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('smtp', data) + + @patch('app.email_manager') + def test_connectivity_returns_500_on_exception(self, mock_em): + mock_em.test_connectivity.side_effect = Exception('timeout') + r = self.client.get('/api/email/connectivity') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_enforce_auth_configured.py b/tests/test_enforce_auth_configured.py new file mode 100644 index 0000000..075832c --- /dev/null +++ b/tests/test_enforce_auth_configured.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" +Tests for the enforce_auth before_request hook in api/app.py. + +The hook has two distinct behaviours depending on the auth store state: + - users file exists and is POPULATED → auth is enforced (unauthenticated → 401) + - users file exists but is EMPTY → 503 (auth not configured) + - users file does not exist / unreadable → bypass (pre-auth compat mode) + +These tests create real AuthManager instances pointing at tmp directories so +that list_users() and the file-readability check both behave exactly as they +do in production. +""" + +import os +import sys +import json +import pytest +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + + +@pytest.fixture +def flask_client(): + from app import app + app.config['TESTING'] = True + return app.test_client() + + +@pytest.fixture +def populated_auth_manager(tmp_path): + """AuthManager whose users file contains at least one admin account.""" + from auth_manager import AuthManager + data_dir = str(tmp_path / 'data') + config_dir = str(tmp_path / 'config') + os.makedirs(data_dir, exist_ok=True) + os.makedirs(config_dir, exist_ok=True) + mgr = AuthManager(data_dir=data_dir, config_dir=config_dir) + # Create an admin so list_users() is non-empty + ok = mgr.create_user('admin', 'AdminPass123!', 'admin') + assert ok, 'Could not seed admin user for test' + return mgr + + +@pytest.fixture +def empty_auth_manager(tmp_path): + """AuthManager whose users file exists and is readable but contains no users.""" + from auth_manager import AuthManager + data_dir = str(tmp_path / 'data') + config_dir = str(tmp_path / 'config') + os.makedirs(data_dir, exist_ok=True) + os.makedirs(config_dir, exist_ok=True) + mgr = AuthManager(data_dir=data_dir, config_dir=config_dir) + # The constructor creates the file with '[]' (empty list). We do NOT add + # any user, so list_users() returns [] but the file is readable. + assert mgr.list_users() == [], 'Expected empty user list' + return mgr + + +# ── populated store → auth enforced ────────────────────────────────────────── + +def test_populated_auth_manager_unauthenticated_request_gets_401( + flask_client, populated_auth_manager +): + """When the auth store has users, unauthenticated API requests must get 401.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/api/status') + assert r.status_code == 401 + data = json.loads(r.data) + assert 'error' in data + + +def test_populated_auth_manager_401_body_says_not_authenticated( + flask_client, populated_auth_manager +): + """The 401 body must clearly indicate the session is missing.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/api/peers') + assert r.status_code == 401 + data = json.loads(r.data) + assert 'Not authenticated' in data.get('error', '') + + +def test_populated_auth_manager_non_api_path_bypasses_auth( + flask_client, populated_auth_manager +): + """Non-API paths like /health must always be public.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/health') + assert r.status_code == 200 + + +def test_populated_auth_manager_auth_namespace_bypasses_auth( + flask_client, populated_auth_manager +): + """The /api/auth/* namespace must always be accessible without a session.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/api/auth/me') + # /api/auth/me may return 401 from the route itself (no session), but it + # must NOT be blocked by enforce_auth; the enforce_auth hook must return None + # for /api/auth/* paths. The status must not be 503. + assert r.status_code != 503 + + +# ── empty store → 503 ──────────────────────────────────────────────────────── + +def test_empty_auth_manager_returns_503_for_api_requests( + flask_client, empty_auth_manager +): + """When the users file exists and is readable but empty, /api/* must get 503.""" + with patch('app.auth_manager', empty_auth_manager): + r = flask_client.get('/api/status') + assert r.status_code == 503 + data = json.loads(r.data) + assert 'error' in data + + +def test_empty_auth_manager_503_body_mentions_configuration( + flask_client, empty_auth_manager +): + """The 503 error body must mention that auth is not configured.""" + with patch('app.auth_manager', empty_auth_manager): + r = flask_client.get('/api/config') + assert r.status_code == 503 + data = json.loads(r.data) + error_text = data.get('error', '') + assert 'not configured' in error_text.lower() or 'Authentication' in error_text + + +def test_empty_auth_manager_non_api_path_bypasses_503( + flask_client, empty_auth_manager +): + """Even with an empty auth store, /health must remain accessible.""" + with patch('app.auth_manager', empty_auth_manager): + r = flask_client.get('/health') + assert r.status_code == 200 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_file_endpoints.py b/tests/test_file_endpoints.py index 15ba155..e9080f2 100644 --- a/tests/test_file_endpoints.py +++ b/tests/test_file_endpoints.py @@ -231,7 +231,7 @@ class TestFileCreateFolderEndpoint(unittest.TestCase): mock_fm.create_folder.return_value = True r = self.client.post( '/api/files/folders', - data=json.dumps({'username': 'alice', 'folder': 'Archive'}), + data=json.dumps({'username': 'alice', 'folder_path': 'Archive'}), content_type='application/json', ) self.assertEqual(r.status_code, 200) @@ -247,7 +247,7 @@ class TestFileCreateFolderEndpoint(unittest.TestCase): mock_fm.create_folder.side_effect = Exception('quota exceeded') r = self.client.post( '/api/files/folders', - data=json.dumps({'username': 'alice', 'folder': 'NewFolder'}), + data=json.dumps({'username': 'alice', 'folder_path': 'NewFolder'}), content_type='application/json', ) self.assertEqual(r.status_code, 500) diff --git a/tests/test_firewall_manager.py b/tests/test_firewall_manager.py index 024d152..cae337d 100644 --- a/tests/test_firewall_manager.py +++ b/tests/test_firewall_manager.py @@ -30,10 +30,12 @@ def _make_peer(ip, internet=True, services=None, peers=True): class TestPeerComment(unittest.TestCase): def test_dots_replaced_with_dashes(self): - self.assertEqual(firewall_manager._peer_comment('10.0.0.2'), 'pic-peer-10-0-0-2') + # Comment format now includes /32 suffix to prevent substring matches + # (e.g. pic-peer-10-0-0-1/32 is not a prefix of pic-peer-10-0-0-10/32) + self.assertEqual(firewall_manager._peer_comment('10.0.0.2'), 'pic-peer-10-0-0-2/32') def test_different_ip(self): - self.assertEqual(firewall_manager._peer_comment('192.168.1.100'), 'pic-peer-192-168-1-100') + self.assertEqual(firewall_manager._peer_comment('192.168.1.100'), 'pic-peer-192-168-1-100/32') # --------------------------------------------------------------------------- @@ -115,6 +117,87 @@ class TestGenerateCorefile(unittest.TestCase): self.assertFalse(result) +# --------------------------------------------------------------------------- +# generate_corefile with cell_links +# --------------------------------------------------------------------------- + +class TestGenerateCorefileWithCellLinks(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.path = os.path.join(self.tmp, 'Corefile') + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _content(self): + return open(self.path).read() + + def test_cell_links_none_produces_no_forwarding_stanzas(self): + """Default (None) produces no extra forwarding blocks beyond the primary zone.""" + firewall_manager.generate_corefile([], self.path, cell_links=None) + content = self._content() + # The only 'forward' line should be the default internet forwarder + forward_lines = [l for l in content.splitlines() if 'forward' in l] + self.assertEqual(len(forward_lines), 1) + self.assertIn('8.8.8.8', forward_lines[0]) + + def test_cell_links_empty_list_produces_no_extra_stanzas(self): + """An empty cell_links list produces no extra forwarding blocks.""" + firewall_manager.generate_corefile([], self.path, cell_links=[]) + content = self._content() + forward_lines = [l for l in content.splitlines() if 'forward' in l] + self.assertEqual(len(forward_lines), 1) + self.assertIn('8.8.8.8', forward_lines[0]) + + def test_single_cell_link_produces_forwarding_block(self): + """One cell link produces one forwarding stanza with correct domain and dns_ip.""" + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.1.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + self.assertIn('remote.cell {', content) + self.assertIn('forward . 10.1.0.1', content) + self.assertIn('cache', content) + + def test_multiple_cell_links_produce_multiple_forwarding_blocks(self): + """Multiple cell links produce one stanza each.""" + cell_links = [ + {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + self.assertIn('alpha.cell {', content) + self.assertIn('forward . 10.1.0.1', content) + self.assertIn('beta.cell {', content) + self.assertIn('forward . 10.2.0.1', content) + + def test_cell_links_do_not_overwrite_peer_acls(self): + """Cell link stanzas are appended; peer ACLs in the primary zone survive.""" + peers = [_make_peer('10.0.0.3', services=['calendar'])] + cell_links = [{'domain': 'other.cell', 'dns_ip': '10.99.0.1'}] + firewall_manager.generate_corefile(peers, self.path, cell_links=cell_links) + content = self._content() + self.assertIn('block net 10.0.0.3/32', content) + self.assertIn('other.cell {', content) + self.assertIn('forward . 10.99.0.1', content) + + def test_link_with_missing_domain_is_skipped(self): + """A cell_link entry with no domain key is silently skipped.""" + cell_links = [{'dns_ip': '10.1.0.1'}] # no 'domain' + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + # Only the default internet forwarder + forward_lines = [l for l in content.splitlines() if 'forward' in l] + self.assertEqual(len(forward_lines), 1) + + def test_link_with_missing_dns_ip_is_skipped(self): + """A cell_link entry with no dns_ip key is silently skipped.""" + cell_links = [{'domain': 'nope.cell'}] # no 'dns_ip' + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + self.assertNotIn('nope.cell', content) + + # --------------------------------------------------------------------------- # apply_peer_rules — iptables call verification # --------------------------------------------------------------------------- @@ -227,8 +310,8 @@ class TestClearPeerRules(unittest.TestCase): '*filter\n' ':INPUT ACCEPT [0:0]\n' ':FORWARD ACCEPT [0:0]\n' - '-A FORWARD -s 10.0.0.2 -m comment --comment pic-peer-10-0-0-2 -j ACCEPT\n' - '-A FORWARD -s 10.0.0.3 -m comment --comment pic-peer-10-0-0-3 -j DROP\n' + '-A FORWARD -s 10.0.0.2 -m comment --comment "pic-peer-10-0-0-2/32" -j ACCEPT\n' + '-A FORWARD -s 10.0.0.3 -m comment --comment "pic-peer-10-0-0-3/32" -j DROP\n' 'COMMIT\n' ) restored = [] @@ -252,8 +335,8 @@ class TestClearPeerRules(unittest.TestCase): self.assertEqual(len(restored), 1) restored_content = restored[0] - self.assertNotIn('pic-peer-10-0-0-2', restored_content) - self.assertIn('pic-peer-10-0-0-3', restored_content) + self.assertNotIn('pic-peer-10-0-0-2/32', restored_content) + self.assertIn('pic-peer-10-0-0-3/32', restored_content) def test_no_op_when_no_matching_rules(self): save_output = '*filter\n:FORWARD ACCEPT [0:0]\nCOMMIT\n' diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py new file mode 100644 index 0000000..7b19242 --- /dev/null +++ b/tests/test_input_validation.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +""" +Tests for the security input validation on PUT /api/config. + +Validates that domain and cell_name fields reject injection characters +while allowing legitimate values (multi-label domains, hyphens, etc.). +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +def _put(client, payload): + return client.put( + '/api/config', + data=json.dumps(payload), + content_type='application/json', + ) + + +class TestDomainValidation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + def test_domain_with_newline_returns_400(self): + r = _put(self.client, {'domain': 'cell\nnewline'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_with_opening_brace_returns_400(self): + r = _put(self.client, {'domain': 'cell{injection}'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_with_semicolon_returns_400(self): + r = _put(self.client, {'domain': 'cell;rm -rf /'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_with_space_returns_400(self): + r = _put(self.client, {'domain': 'my cell'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_multilabel_with_dot_returns_200(self): + # Multi-label names like 'cell.local' or 'home.lan' must be accepted. + r = _put(self.client, {'domain': 'cell.local'}) + # The endpoint may also return non-400 on 500 if downstream fails, + # but the validation itself must not reject dots. + self.assertNotEqual(r.status_code, 400) + + def test_domain_simple_word_returns_200(self): + r = _put(self.client, {'domain': 'myhome'}) + self.assertNotEqual(r.status_code, 400) + + def test_domain_with_hyphen_returns_200(self): + r = _put(self.client, {'domain': 'my-cell'}) + self.assertNotEqual(r.status_code, 400) + + def test_domain_with_at_sign_returns_400(self): + r = _put(self.client, {'domain': 'cell@evil.com'}) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + def test_domain_with_slash_returns_400(self): + r = _put(self.client, {'domain': 'cell/etc'}) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + +class TestCellNameValidation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + def test_cell_name_with_space_returns_400(self): + r = _put(self.client, {'cell_name': 'my cell'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_with_dot_returns_400(self): + # cell_name is a single hostname component — dots are not allowed + r = _put(self.client, {'cell_name': 'my.cell'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_with_newline_returns_400(self): + r = _put(self.client, {'cell_name': 'cell\nevil'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_with_semicolon_returns_400(self): + r = _put(self.client, {'cell_name': 'cell;drop'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_valid_hyphenated_returns_200(self): + r = _put(self.client, {'cell_name': 'valid-name'}) + self.assertNotEqual(r.status_code, 400) + + def test_cell_name_simple_alpha_returns_200(self): + r = _put(self.client, {'cell_name': 'mycell'}) + self.assertNotEqual(r.status_code, 400) + + def test_cell_name_with_digits_returns_200(self): + r = _put(self.client, {'cell_name': 'cell01'}) + self.assertNotEqual(r.status_code, 400) + + def test_cell_name_with_brace_returns_400(self): + r = _put(self.client, {'cell_name': 'cell{x}'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_is_local_request_per_endpoint.py b/tests/test_is_local_request_per_endpoint.py new file mode 100644 index 0000000..b8b0934 --- /dev/null +++ b/tests/test_is_local_request_per_endpoint.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +""" +Tests verifying that is_local_request() enforcement works correctly +per endpoint in api/app.py. + +The audit flagged that is_local_request() checks are performed inline +(not via a decorator), so this file confirms: + 1. Endpoints that call `is_local_request()` return 403 when the + function returns False (i.e., a non-local caller). + 2. Endpoints that do NOT call `is_local_request()` still respond + normally (2xx / 4xx) for non-local callers. + +Tested local-only endpoints (representative sample): + GET /api/containers — list_containers + POST /api/containers//start + POST /api/containers//stop + POST /api/containers//restart + GET /api/containers//logs + GET /api/containers//stats + GET /api/vault/secrets + POST /api/vault/secrets + GET /api/vault/secrets/ + DELETE /api/vault/secrets/ + GET /api/containers — POST with image field + GET /api/images + POST /api/images/pull + DELETE /api/images/ + GET /api/volumes + POST /api/volumes + DELETE /api/volumes/ + DELETE /api/containers/ + +Tested public endpoints (no is_local_request guard): + GET /api/calendar/status + GET /api/dns/records + GET /api/dhcp/leases + GET /api/cells +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +def _non_local_client(): + """Return a Flask test client that pretends to come from a non-local address.""" + app.config['TESTING'] = True + # Flask's test client uses '127.0.0.1' by default; override with a public IP + # by setting REMOTE_ADDR in the environ base. + return app.test_client() + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _get_non_local(client, path): + """Perform a GET request that appears to originate from a non-local IP.""" + return client.get(path, environ_base={'REMOTE_ADDR': '203.0.113.1'}) + + +def _post_non_local(client, path, body=None): + return client.post( + path, + data=json.dumps(body or {}), + content_type='application/json', + environ_base={'REMOTE_ADDR': '203.0.113.1'}, + ) + + +def _delete_non_local(client, path): + return client.delete(path, environ_base={'REMOTE_ADDR': '203.0.113.1'}) + + +# ── local-only endpoint tests ───────────────────────────────────────────────── + +class TestLocalOnlyEndpointsReturn403ForNonLocal(unittest.TestCase): + """Every endpoint that calls is_local_request() must return 403 for external IPs.""" + + def setUp(self): + app.config['TESTING'] = True + self.client = _non_local_client() + + # Container management + + def test_list_containers_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/containers') + self.assertEqual(r.status_code, 403) + self.assertIn('error', json.loads(r.data)) + + def test_start_container_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/containers/myapp/start') + self.assertEqual(r.status_code, 403) + + def test_stop_container_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/containers/myapp/stop') + self.assertEqual(r.status_code, 403) + + def test_restart_container_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/containers/myapp/restart') + self.assertEqual(r.status_code, 403) + + def test_get_container_logs_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/containers/myapp/logs') + self.assertEqual(r.status_code, 403) + + def test_get_container_stats_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/containers/myapp/stats') + self.assertEqual(r.status_code, 403) + + def test_remove_container_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/containers/myapp') + self.assertEqual(r.status_code, 403) + + # Image management + + def test_list_images_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/images') + self.assertEqual(r.status_code, 403) + + def test_pull_image_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/images/pull', {'image': 'nginx:latest'}) + self.assertEqual(r.status_code, 403) + + def test_remove_image_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/images/nginx') + self.assertEqual(r.status_code, 403) + + # Volume management + + def test_list_volumes_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/volumes') + self.assertEqual(r.status_code, 403) + + def test_create_volume_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/volumes', {'name': 'myvol'}) + self.assertEqual(r.status_code, 403) + + def test_remove_volume_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/volumes/myvol') + self.assertEqual(r.status_code, 403) + + # Vault endpoints + + def test_list_secrets_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/vault/secrets') + self.assertEqual(r.status_code, 403) + + def test_store_secret_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/vault/secrets', {'name': 'k', 'value': 'v'}) + self.assertEqual(r.status_code, 403) + + def test_get_secret_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/vault/secrets/mykey') + self.assertEqual(r.status_code, 403) + + def test_delete_secret_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/vault/secrets/mykey') + self.assertEqual(r.status_code, 403) + + +class TestLocalOnlyEndpointsAllowedFromLocalhost(unittest.TestCase): + """The same endpoints must NOT return 403 for loopback / local callers.""" + + def setUp(self): + app.config['TESTING'] = True + # Default test client remote_addr is 127.0.0.1, which is local + self.client = app.test_client() + + @patch('app.container_manager') + def test_list_containers_allowed_from_local(self, mock_cm): + mock_cm.list_containers.return_value = [] + r = self.client.get('/api/containers') + self.assertNotEqual(r.status_code, 403) + + @patch('app.container_manager') + def test_list_images_allowed_from_local(self, mock_cm): + mock_cm.list_images.return_value = [] + r = self.client.get('/api/images') + self.assertNotEqual(r.status_code, 403) + + @patch('app.container_manager') + def test_list_volumes_allowed_from_local(self, mock_cm): + mock_cm.list_volumes.return_value = [] + r = self.client.get('/api/volumes') + self.assertNotEqual(r.status_code, 403) + + +# ── public endpoint tests — no is_local_request guard ──────────────────────── + +class TestPublicEndpointsNotBlockedForNonLocal(unittest.TestCase): + """ + Endpoints that do NOT call is_local_request() must remain reachable + from non-local addresses. A 403 here would indicate an unintended + broadening of the local-only guard. + """ + + def setUp(self): + app.config['TESTING'] = True + self.client = _non_local_client() + + @patch('app.calendar_manager') + def test_calendar_status_not_403_for_non_local(self, mock_cm): + mock_cm.get_status.return_value = {'running': True} + r = _get_non_local(self.client, '/api/calendar/status') + self.assertNotEqual(r.status_code, 403) + + @patch('app.network_manager') + def test_dns_records_not_403_for_non_local(self, mock_nm): + mock_nm.get_dns_records.return_value = [] + r = _get_non_local(self.client, '/api/dns/records') + self.assertNotEqual(r.status_code, 403) + + @patch('app.network_manager') + def test_dhcp_leases_not_403_for_non_local(self, mock_nm): + mock_nm.get_dhcp_leases.return_value = [] + r = _get_non_local(self.client, '/api/dhcp/leases') + self.assertNotEqual(r.status_code, 403) + + @patch('app.cell_link_manager') + def test_cells_list_not_403_for_non_local(self, mock_clm): + mock_clm.list_connections.return_value = [] + r = _get_non_local(self.client, '/api/cells') + self.assertNotEqual(r.status_code, 403) + + def test_health_check_not_403_for_non_local(self): + r = _get_non_local(self.client, '/health') + self.assertNotEqual(r.status_code, 403) + + +# ── is_local_request logic unit tests ──────────────────────────────────────── + +class TestIsLocalRequestLogic(unittest.TestCase): + """ + Directly verify the is_local_request() function from app.py. + These tests exercise the address-checking logic without going through + a full HTTP request cycle. + """ + + def setUp(self): + from app import is_local_request as _fn + self._fn = _fn + app.config['TESTING'] = True + + def _call_with_addr(self, remote_addr, xff=None): + """Push a fake request context and evaluate is_local_request().""" + from app import app as _app + headers = {} + if xff: + headers['X-Forwarded-For'] = xff + with _app.test_request_context('/', environ_base={'REMOTE_ADDR': remote_addr}, + headers=headers): + return self._fn() + + def test_loopback_127_is_local(self): + self.assertTrue(self._call_with_addr('127.0.0.1')) + + def test_ipv6_loopback_is_local(self): + self.assertTrue(self._call_with_addr('::1')) + + def test_docker_bridge_172_20_is_local(self): + # 172.20.x.x is inside 172.16.0.0/12 + self.assertTrue(self._call_with_addr('172.20.0.5')) + + def test_docker_bridge_172_16_boundary_is_local(self): + # Exact boundary of 172.16.0.0/12 + self.assertTrue(self._call_with_addr('172.16.0.1')) + + def test_public_ip_is_not_local(self): + self.assertFalse(self._call_with_addr('8.8.8.8')) + + def test_wireguard_peer_10_0_0_x_is_not_local(self): + # WireGuard peer IPs (10.0.0.0/8) must NOT be treated as local + self.assertFalse(self._call_with_addr('10.0.0.2')) + + def test_lan_192_168_is_not_local(self): + # LAN addresses must NOT be treated as local (comment in app.py confirms this) + self.assertFalse(self._call_with_addr('192.168.1.50')) + + def test_xff_last_entry_loopback_is_local(self): + # Public remote addr, but last XFF entry is loopback → allowed + self.assertTrue(self._call_with_addr('8.8.8.8', xff='8.8.8.8, 127.0.0.1')) + + def test_xff_first_entry_spoofed_loopback_not_local(self): + # Spoofed first XFF entry; last entry is a public IP → should be rejected + # remote_addr is also public to rule out that shortcut + result = self._call_with_addr('8.8.8.8', xff='127.0.0.1, 8.8.8.8') + self.assertFalse(result) + + def test_xff_last_entry_docker_bridge_is_local(self): + # Last XFF entry is Caddy's Docker bridge address + self.assertTrue(self._call_with_addr('8.8.8.8', xff='1.2.3.4, 172.20.0.2')) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_logs_endpoints.py b/tests/test_logs_endpoints.py new file mode 100644 index 0000000..80d7a7c --- /dev/null +++ b/tests/test_logs_endpoints.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +Unit tests for logs Flask endpoints in api/app.py. + +Covers: + GET /api/logs — backend log file (reads picell.log) + GET /api/logs/services/ — per-service logs via log_manager + POST /api/logs/search — search across services + POST /api/logs/export — export logs + GET /api/logs/statistics — log stats + POST /api/logs/rotate — rotate logs + GET /api/logs/files — list log file info + GET /api/logs/verbosity — get log levels + PUT /api/logs/verbosity — set log levels +""" + +import sys +import json +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock, mock_open + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetBackendLogs(unittest.TestCase): + """GET /api/logs — reads picell.log from api directory.""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + def test_get_logs_returns_404_when_log_file_missing(self): + # Patch os.path.exists so the log file appears absent + with patch('app.os.path.exists', return_value=False): + r = self.client.get('/api/logs') + self.assertEqual(r.status_code, 404) + self.assertIn('error', json.loads(r.data)) + + def test_get_logs_returns_200_with_log_content(self): + log_content = 'INFO 2026-04-27 server started\nERROR something went wrong\n' + m = mock_open(read_data=log_content) + # Bypass auth enforcement by replacing auth_manager with a non-AuthManager object + with patch('app.auth_manager', MagicMock(spec=object)), \ + patch('app.os.path.exists', return_value=True), \ + patch('builtins.open', m): + r = self.client.get('/api/logs') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('log', data) + + def test_get_logs_respects_lines_query_param(self): + # Produce 10 lines; request only last 3 + all_lines = [f'line {i}\n' for i in range(10)] + m = mock_open(read_data=''.join(all_lines)) + m.return_value.__iter__ = lambda s: iter(all_lines) + m.return_value.readlines = lambda: all_lines + # Bypass auth enforcement by replacing auth_manager with a non-AuthManager object + with patch('app.auth_manager', MagicMock(spec=object)), \ + patch('app.os.path.exists', return_value=True), \ + patch('builtins.open', m): + r = self.client.get('/api/logs?lines=3') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + # The tail should contain only the last 3 lines + self.assertIn('line 7', data['log']) + self.assertIn('line 8', data['log']) + self.assertIn('line 9', data['log']) + + def test_get_logs_returns_500_on_exception(self): + with patch('app.os.path.exists', return_value=True), \ + patch('builtins.open', side_effect=PermissionError('access denied')): + r = self.client.get('/api/logs') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetServiceLogs(unittest.TestCase): + """GET /api/logs/services/""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_service_logs_returns_200_with_log_data(self, mock_lm): + mock_lm.get_service_logs.return_value = [ + '[INFO] 2026-04-27 dns started', + '[WARN] 2026-04-27 retry attempt', + ] + r = self.client.get('/api/logs/services/dns') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertEqual(data['service'], 'dns') + self.assertIsInstance(data['logs'], list) + self.assertEqual(len(data['logs']), 2) + + @patch('app.log_manager') + def test_get_service_logs_passes_level_and_lines_params(self, mock_lm): + mock_lm.get_service_logs.return_value = [] + self.client.get('/api/logs/services/email?level=ERROR&lines=25') + mock_lm.get_service_logs.assert_called_once_with('email', 'ERROR', 25) + + @patch('app.log_manager') + def test_get_service_logs_uses_defaults_when_params_absent(self, mock_lm): + mock_lm.get_service_logs.return_value = [] + self.client.get('/api/logs/services/wireguard') + mock_lm.get_service_logs.assert_called_once_with('wireguard', 'INFO', 50) + + @patch('app.log_manager') + def test_get_service_logs_returns_500_on_exception(self, mock_lm): + mock_lm.get_service_logs.side_effect = Exception('log file missing') + r = self.client.get('/api/logs/services/calendar') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestSearchLogs(unittest.TestCase): + """POST /api/logs/search""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_search_logs_returns_200_with_results_and_count(self, mock_lm): + mock_lm.search_logs.return_value = [ + {'service': 'dns', 'line': 'ERROR timeout'}, + ] + r = self.client.post( + '/api/logs/search', + data=json.dumps({'query': 'ERROR', 'services': ['dns']}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('results', data) + self.assertIn('count', data) + self.assertEqual(data['count'], 1) + + @patch('app.log_manager') + def test_search_logs_works_with_empty_body(self, mock_lm): + mock_lm.search_logs.return_value = [] + r = self.client.post('/api/logs/search') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertEqual(data['results'], []) + self.assertEqual(data['count'], 0) + + @patch('app.log_manager') + def test_search_logs_returns_500_on_exception(self, mock_lm): + mock_lm.search_logs.side_effect = Exception('index unavailable') + r = self.client.post( + '/api/logs/search', + data=json.dumps({'query': 'fail'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestExportLogs(unittest.TestCase): + """POST /api/logs/export""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_export_logs_returns_200_with_log_data_and_format(self, mock_lm): + mock_lm.export_logs.return_value = '[{"ts":1,"msg":"ok"}]' + r = self.client.post( + '/api/logs/export', + data=json.dumps({'format': 'json', 'filters': {'service': 'dns'}}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('logs', data) + self.assertIn('format', data) + self.assertEqual(data['format'], 'json') + + @patch('app.log_manager') + def test_export_logs_defaults_to_json_format(self, mock_lm): + mock_lm.export_logs.return_value = '[]' + self.client.post('/api/logs/export') + mock_lm.export_logs.assert_called_once_with('json', {}) + + @patch('app.log_manager') + def test_export_logs_returns_500_on_exception(self, mock_lm): + mock_lm.export_logs.side_effect = Exception('export failed') + r = self.client.post( + '/api/logs/export', + data=json.dumps({'format': 'csv'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetLogStatistics(unittest.TestCase): + """GET /api/logs/statistics""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_statistics_returns_200_with_stats_dict(self, mock_lm): + mock_lm.get_log_statistics.return_value = { + 'total_lines': 1200, + 'error_count': 3, + 'warn_count': 17, + } + r = self.client.get('/api/logs/statistics') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('total_lines', data) + + @patch('app.log_manager') + def test_get_statistics_passes_service_param(self, mock_lm): + mock_lm.get_log_statistics.return_value = {} + self.client.get('/api/logs/statistics?service=email') + mock_lm.get_log_statistics.assert_called_once_with('email') + + @patch('app.log_manager') + def test_get_statistics_passes_none_when_no_service_param(self, mock_lm): + mock_lm.get_log_statistics.return_value = {} + self.client.get('/api/logs/statistics') + mock_lm.get_log_statistics.assert_called_once_with(None) + + @patch('app.log_manager') + def test_get_statistics_returns_500_on_exception(self, mock_lm): + mock_lm.get_log_statistics.side_effect = Exception('stats error') + r = self.client.get('/api/logs/statistics') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestRotateLogs(unittest.TestCase): + """POST /api/logs/rotate""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_rotate_all_logs_returns_200(self, mock_lm): + r = self.client.post('/api/logs/rotate') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + mock_lm.rotate_logs.assert_called_once_with(None) + + @patch('app.log_manager') + def test_rotate_specific_service_passes_service_name(self, mock_lm): + r = self.client.post( + '/api/logs/rotate', + data=json.dumps({'service': 'dns'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + mock_lm.rotate_logs.assert_called_once_with('dns') + + @patch('app.log_manager') + def test_rotate_returns_500_on_exception(self, mock_lm): + mock_lm.rotate_logs.side_effect = Exception('rotate failed') + r = self.client.post('/api/logs/rotate') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetLogFileInfos(unittest.TestCase): + """GET /api/logs/files""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_log_files_returns_200_with_file_list(self, mock_lm): + mock_lm.get_all_log_file_infos.return_value = [ + {'service': 'dns', 'path': '/data/logs/dns.log', 'size_bytes': 4096}, + {'service': 'email', 'path': '/data/logs/email.log', 'size_bytes': 8192}, + ] + r = self.client.get('/api/logs/files') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.log_manager') + def test_get_log_files_returns_500_on_exception(self, mock_lm): + mock_lm.get_all_log_file_infos.side_effect = Exception('filesystem error') + r = self.client.get('/api/logs/files') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestLogVerbosity(unittest.TestCase): + """GET /api/logs/verbosity and PUT /api/logs/verbosity""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_verbosity_returns_200_with_levels_map(self, mock_lm): + mock_lm.get_service_levels.return_value = { + 'dns': 'INFO', + 'email': 'DEBUG', + 'wireguard': 'WARNING', + } + r = self.client.get('/api/logs/verbosity') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('dns', data) + self.assertEqual(data['email'], 'DEBUG') + + @patch('app.log_manager') + def test_get_verbosity_returns_500_on_exception(self, mock_lm): + mock_lm.get_service_levels.side_effect = Exception('config missing') + r = self.client.get('/api/logs/verbosity') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + @patch('app.log_manager') + def test_put_verbosity_returns_200_and_calls_set_level(self, mock_lm): + mock_lm.get_service_levels.return_value = {'dns': 'DEBUG'} + with tempfile.TemporaryDirectory() as tmpdir: + # Endpoint builds: os.path.join(os.path.dirname(__file__), 'config', 'log_levels.json') + # Patch dirname to return tmpdir so the full path becomes tmpdir/config/log_levels.json + config_dir = os.path.join(tmpdir, 'config') + os.makedirs(config_dir) + with patch('app.auth_manager', MagicMock(spec=object)), \ + patch('app.os.path.dirname', return_value=tmpdir): + r = self.client.put( + '/api/logs/verbosity', + data=json.dumps({'dns': 'DEBUG'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + mock_lm.set_service_level.assert_called_with('dns', 'DEBUG') + + @patch('app.log_manager') + def test_put_verbosity_returns_500_on_exception(self, mock_lm): + mock_lm.set_service_level.side_effect = Exception('unknown service') + r = self.client.put( + '/api/logs/verbosity', + data=json.dumps({'unknown_svc': 'DEBUG'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_network_endpoints.py b/tests/test_network_endpoints.py index f2579ef..b5b1c14 100644 --- a/tests/test_network_endpoints.py +++ b/tests/test_network_endpoints.py @@ -1 +1,353 @@ -# ... moved and adapted code from test_phase1_endpoints.py ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for network/DNS/DHCP Flask endpoints in api/app.py. + +Covers: + GET /api/dns/records + POST /api/dns/records + DELETE /api/dns/records + GET /api/dns/status + GET /api/dhcp/leases + POST /api/dhcp/reservations + DELETE /api/dhcp/reservations + GET /api/network/info + POST /api/network/test +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetDnsRecords(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_dns_records_returns_200_with_list(self, mock_nm): + mock_nm.get_dns_records.return_value = [ + {'name': 'myhost.cell', 'type': 'A', 'value': '192.168.1.10'}, + {'name': 'nas.cell', 'type': 'A', 'value': '192.168.1.20'}, + ] + r = self.client.get('/api/dns/records') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.network_manager') + def test_get_dns_records_returns_empty_list_when_none(self, mock_nm): + mock_nm.get_dns_records.return_value = [] + r = self.client.get('/api/dns/records') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.network_manager') + def test_get_dns_records_returns_500_on_exception(self, mock_nm): + mock_nm.get_dns_records.side_effect = Exception('CoreDNS unreachable') + r = self.client.get('/api/dns/records') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddDnsRecord(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_add_dns_record_returns_200_on_valid_body(self, mock_nm): + mock_nm.add_dns_record.return_value = {'success': True} + r = self.client.post( + '/api/dns/records', + data=json.dumps({'name': 'printer.cell', 'type': 'A', 'value': '192.168.1.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('success', data) + + @patch('app.network_manager') + def test_add_dns_record_returns_400_when_no_body(self, mock_nm): + r = self.client.post('/api/dns/records') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_nm.add_dns_record.assert_not_called() + + @patch('app.network_manager') + def test_add_dns_record_returns_500_on_exception(self, mock_nm): + mock_nm.add_dns_record.side_effect = Exception('Corefile write failed') + r = self.client.post( + '/api/dns/records', + data=json.dumps({'name': 'bad.cell', 'type': 'A', 'value': '10.0.0.1'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteDnsRecord(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_delete_dns_record_returns_200_on_success(self, mock_nm): + mock_nm.remove_dns_record.return_value = {'success': True} + r = self.client.delete( + '/api/dns/records', + data=json.dumps({'name': 'printer.cell'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + + @patch('app.network_manager') + def test_delete_dns_record_returns_500_on_exception(self, mock_nm): + mock_nm.remove_dns_record.side_effect = Exception('record not found') + r = self.client.delete( + '/api/dns/records', + data=json.dumps({'name': 'missing.cell'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetDnsStatus(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_dns_status_returns_200_with_status_dict(self, mock_nm): + mock_nm.get_dns_status.return_value = { + 'running': True, + 'records_count': 5, + 'upstreams': ['1.1.1.1', '8.8.8.8'], + } + r = self.client.get('/api/dns/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('running', data) + + @patch('app.network_manager') + def test_get_dns_status_returns_500_on_exception(self, mock_nm): + mock_nm.get_dns_status.side_effect = Exception('CoreDNS not running') + r = self.client.get('/api/dns/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetDhcpLeases(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_dhcp_leases_returns_200_with_list(self, mock_nm): + mock_nm.get_dhcp_leases.return_value = [ + {'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.101', 'hostname': 'laptop'}, + ] + r = self.client.get('/api/dhcp/leases') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 1) + self.assertEqual(data[0]['hostname'], 'laptop') + + @patch('app.network_manager') + def test_get_dhcp_leases_returns_empty_list_when_no_leases(self, mock_nm): + mock_nm.get_dhcp_leases.return_value = [] + r = self.client.get('/api/dhcp/leases') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.network_manager') + def test_get_dhcp_leases_returns_500_on_exception(self, mock_nm): + mock_nm.get_dhcp_leases.side_effect = Exception('dnsmasq not running') + r = self.client.get('/api/dhcp/leases') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddDhcpReservation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_add_reservation_returns_200_on_valid_body(self, mock_nm): + mock_nm.add_dhcp_reservation.return_value = True + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50', 'hostname': 'printer'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('success', data) + + @patch('app.network_manager') + def test_add_reservation_returns_400_when_no_body(self, mock_nm): + r = self.client.post('/api/dhcp/reservations') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_nm.add_dhcp_reservation.assert_not_called() + + @patch('app.network_manager') + def test_add_reservation_returns_400_when_mac_missing(self, mock_nm): + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'ip': '192.168.1.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.network_manager') + def test_add_reservation_returns_400_when_ip_missing(self, mock_nm): + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.network_manager') + def test_add_reservation_uses_empty_hostname_when_omitted(self, mock_nm): + mock_nm.add_dhcp_reservation.return_value = True + self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50'}), + content_type='application/json', + ) + mock_nm.add_dhcp_reservation.assert_called_once_with('aa:bb:cc:dd:ee:ff', '192.168.1.50', '') + + @patch('app.network_manager') + def test_add_reservation_returns_500_on_exception(self, mock_nm): + mock_nm.add_dhcp_reservation.side_effect = Exception('dnsmasq config error') + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteDhcpReservation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_delete_reservation_returns_200_on_success(self, mock_nm): + mock_nm.remove_dhcp_reservation.return_value = True + r = self.client.delete( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('success', data) + + @patch('app.network_manager') + def test_delete_reservation_returns_400_when_mac_missing(self, mock_nm): + r = self.client.delete( + '/api/dhcp/reservations', + data=json.dumps({'hostname': 'printer'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_nm.remove_dhcp_reservation.assert_not_called() + + @patch('app.network_manager') + def test_delete_reservation_returns_400_when_no_body(self, mock_nm): + r = self.client.delete('/api/dhcp/reservations') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.network_manager') + def test_delete_reservation_returns_500_on_exception(self, mock_nm): + mock_nm.remove_dhcp_reservation.side_effect = Exception('reservation not found') + r = self.client.delete( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetNetworkInfo(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_network_info_returns_200_with_info_dict(self, mock_nm): + mock_nm.get_network_info.return_value = { + 'interfaces': ['eth0', 'wg0'], + 'gateway': '192.168.1.1', + 'dns': ['127.0.0.1'], + } + r = self.client.get('/api/network/info') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('interfaces', data) + + @patch('app.network_manager') + def test_get_network_info_returns_500_on_exception(self, mock_nm): + mock_nm.get_network_info.side_effect = Exception('network unreachable') + r = self.client.get('/api/network/info') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestNetworkTest(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_network_test_returns_200_with_result(self, mock_nm): + mock_nm.test_connectivity.return_value = { + 'internet': True, + 'dns': True, + 'latency_ms': 15, + } + r = self.client.post('/api/network/test') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('internet', data) + + @patch('app.network_manager') + def test_network_test_returns_500_on_exception(self, mock_nm): + mock_nm.test_connectivity.side_effect = Exception('ping failed') + r = self.client.post('/api/network/test') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_network_manager.py b/tests/test_network_manager.py index 4cf4e2f..40216a7 100644 --- a/tests/test_network_manager.py +++ b/tests/test_network_manager.py @@ -399,11 +399,13 @@ class TestCellDnsForwarding(unittest.TestCase): self.assertNotIn('10.1.0.1', content) @patch('subprocess.run') - def test_remove_nonexistent_forward_is_noop(self, _mock): - before = open(self.corefile).read() - self.nm.remove_cell_dns_forward('nonexistent.cell') + def test_remove_nonexistent_forward_does_not_error(self, _mock): + # Removing a domain that was never added must not raise and must not + # leave the nonexistent domain in the regenerated Corefile. + result = self.nm.remove_cell_dns_forward('nonexistent.cell') after = open(self.corefile).read() - self.assertEqual(before, after) + self.assertNotIn('nonexistent.cell', after) + # The Corefile is regenerated (new canonical format) — that's correct. if __name__ == '__main__': diff --git a/tests/test_peer_dashboard_services.py b/tests/test_peer_dashboard_services.py new file mode 100644 index 0000000..fc83cd6 --- /dev/null +++ b/tests/test_peer_dashboard_services.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python3 +""" +Unit tests for /api/peer/dashboard and /api/peer/services. + +These tests verify the exact JSON field names and structure returned by +both endpoints so UI/API mismatches surface here before reaching users. + +Coverage: + - peer_dashboard returns name/transfer_rx/transfer_tx (not peer_name/rx_bytes/tx_bytes) + - peer_dashboard includes service_urls dict keyed by service name + - peer_services uses files (not webdav) as the file storage key + - peer_services email block uses nested smtp/imap objects with host/port + - peer_services email.address is the full email address (not username) + - peer_services caldav URL uses calendar.{domain}, not radicale.{domain}:5232 + - peer_services wireguard block includes a config text field with DNS = + - peer_services wireguard DNS is not 10.0.0.1 (WireGuard VPN IP, not CoreDNS) + - Unauthenticated requests return 401; admin sessions return 403 (peer-only zone) + - 404 when session has peer_name but peer not in registry +""" + +import os +import sys +import json +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + +from app import app +from auth_manager import AuthManager + + +# ─────────────────────────── helpers ────────────────────────────────────────── + +def _make_auth(tmp_path): + data_dir = str(tmp_path / 'data') + cfg_dir = str(tmp_path / 'cfg') + os.makedirs(data_dir, exist_ok=True) + os.makedirs(cfg_dir, exist_ok=True) + mgr = AuthManager(data_dir=data_dir, config_dir=cfg_dir) + mgr.create_user('admin', 'AdminPass123!', 'admin') + mgr.create_user('alice', 'AlicePass123!', 'peer') + return mgr + + +def _login(client, username, password): + return client.post('/api/auth/login', + data=json.dumps({'username': username, 'password': password}), + content_type='application/json') + + +FAKE_PEER = { + 'peer': 'alice', + 'ip': '10.0.0.5', + 'public_key': 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=', + 'private_key': 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=', + 'allowed_ips': '10.0.0.5/32', + 'internet_access': True, + 'service_access': ['calendar', 'files', 'mail', 'webdav'], + 'active': True, + 'config_needs_reinstall': False, +} + +FAKE_WG_STATS = { + 'online': True, + 'transfer_rx': 1048576, # 1 MiB + 'transfer_tx': 524288, # 512 KiB + 'last_handshake': '2026-04-26T18:00:00', +} + +DOMAIN = 'dev' + + +_REGISTRY_SENTINEL = object() + +def _mock_registry(peer=_REGISTRY_SENTINEL): + reg = MagicMock() + reg.get_peer.return_value = FAKE_PEER if peer is _REGISTRY_SENTINEL else peer + return reg + + +def _mock_wg(dns='172.20.0.3'): + wg = MagicMock() + wg.get_peer_status.return_value = FAKE_WG_STATS + wg.get_keys.return_value = {'public_key': 'SERVERPUBKEY=='} + wg.get_server_config.return_value = {'endpoint': '1.2.3.4:51820'} + wg.FULL_TUNNEL_IPS = '0.0.0.0/0, ::/0' + wg.get_split_tunnel_ips.return_value = '10.0.0.0/24, 172.20.0.0/16' + wg.get_peer_config.return_value = ( + '[Interface]\n' + f'PrivateKey = BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB=\n' + f'Address = 10.0.0.5/32\n' + f'DNS = {dns}\n' + '\n' + '[Peer]\n' + 'PublicKey = SERVERPUBKEY==\n' + 'AllowedIPs = 0.0.0.0/0, ::/0\n' + 'Endpoint = 1.2.3.4:51820\n' + 'PersistentKeepalive = 25\n' + ) + return wg + + +def _mock_config(domain=DOMAIN): + cfg = MagicMock() + cfg.configs = { + '_identity': {'domain': domain, 'cell_name': 'pic0', 'ip_range': '172.20.0.0/16'} + } + return cfg + + +# ─────────────────────────── fixtures ───────────────────────────────────────── + +@pytest.fixture +def auth_mgr(tmp_path): + return _make_auth(tmp_path) + + +@pytest.fixture +def peer_client(auth_mgr): + app.config['TESTING'] = True + app.config['SECRET_KEY'] = 'test-secret' + with patch('app.auth_manager', auth_mgr): + try: + import auth_routes + with patch.object(auth_routes, 'auth_manager', auth_mgr, create=True): + with app.test_client() as c: + r = _login(c, 'alice', 'AlicePass123!') + assert r.status_code == 200, f'peer login failed: {r.data}' + yield c + except (ImportError, AttributeError): + with app.test_client() as c: + r = _login(c, 'alice', 'AlicePass123!') + assert r.status_code == 200, f'peer login failed: {r.data}' + yield c + + +@pytest.fixture +def admin_client(auth_mgr): + app.config['TESTING'] = True + app.config['SECRET_KEY'] = 'test-secret' + with patch('app.auth_manager', auth_mgr): + try: + import auth_routes + with patch.object(auth_routes, 'auth_manager', auth_mgr, create=True): + with app.test_client() as c: + r = _login(c, 'admin', 'AdminPass123!') + assert r.status_code == 200, f'admin login failed: {r.data}' + yield c + except (ImportError, AttributeError): + with app.test_client() as c: + r = _login(c, 'admin', 'AdminPass123!') + assert r.status_code == 200, f'admin login failed: {r.data}' + yield c + + +# ─────────────────── peer_dashboard field names ──────────────────────────────── + +class TestPeerDashboardFieldNames: + """ + peer_dashboard() must return the field names PeerDashboard.jsx reads. + A mismatch causes silent zeros/blanks in the UI without any error. + """ + + def _get(self, peer_client): + wg = _mock_wg() + reg = _mock_registry() + cfg = _mock_config() + with patch('app.peer_registry', reg), \ + patch('app.wireguard_manager', wg), \ + patch('app.config_manager', cfg), \ + patch('app._resolve_peer_dns', return_value='172.20.0.3'): + return peer_client.get('/api/peer/dashboard') + + def test_returns_200(self, peer_client): + r = self._get(peer_client) + assert r.status_code == 200, r.data + + def test_has_name_not_peer_name(self, peer_client): + """PeerDashboard.jsx reads peer.name — must NOT be peer_name.""" + r = self._get(peer_client) + data = r.get_json() + assert 'name' in data, f"'name' missing from dashboard; keys: {list(data)}" + assert 'peer_name' not in data, \ + "'peer_name' still present — UI reads 'name', not 'peer_name'" + + def test_name_value_is_peer_username(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['name'] == 'alice' + + def test_has_transfer_rx_not_rx_bytes(self, peer_client): + """PeerDashboard.jsx reads peer.transfer_rx — must NOT be rx_bytes.""" + r = self._get(peer_client) + data = r.get_json() + assert 'transfer_rx' in data, f"'transfer_rx' missing; keys: {list(data)}" + assert 'rx_bytes' not in data, \ + "'rx_bytes' still present — UI reads 'transfer_rx'" + + def test_has_transfer_tx_not_tx_bytes(self, peer_client): + """PeerDashboard.jsx reads peer.transfer_tx — must NOT be tx_bytes.""" + r = self._get(peer_client) + data = r.get_json() + assert 'transfer_tx' in data, f"'transfer_tx' missing; keys: {list(data)}" + assert 'tx_bytes' not in data, \ + "'tx_bytes' still present — UI reads 'transfer_tx'" + + def test_transfer_rx_value_from_wg_stats(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['transfer_rx'] == 1048576 + + def test_transfer_tx_value_from_wg_stats(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['transfer_tx'] == 524288 + + def test_has_service_urls(self, peer_client): + """Dashboard must include service_urls so UI can render direct service links.""" + r = self._get(peer_client) + data = r.get_json() + assert 'service_urls' in data, f"'service_urls' missing; keys: {list(data)}" + assert isinstance(data['service_urls'], dict) + + def test_service_urls_keyed_by_service(self, peer_client): + r = self._get(peer_client) + urls = r.get_json()['service_urls'] + for svc in ('calendar', 'files', 'mail', 'webdav'): + assert svc in urls, f"service_urls missing '{svc}'; got: {list(urls)}" + + def test_service_urls_use_configured_domain(self, peer_client): + r = self._get(peer_client) + urls = r.get_json()['service_urls'] + assert urls['calendar'] == 'http://calendar.dev' + assert urls['files'] == 'http://files.dev' + assert urls['mail'] == 'http://mail.dev' + assert urls['webdav'] == 'http://webdav.dev' + + def test_online_and_last_handshake_present(self, peer_client): + r = self._get(peer_client) + data = r.get_json() + assert 'online' in data + assert 'last_handshake' in data + + def test_peer_not_in_registry_returns_404(self, peer_client): + reg = _mock_registry(peer=None) + cfg = _mock_config() + with patch('app.peer_registry', reg), patch('app.config_manager', cfg): + r = peer_client.get('/api/peer/dashboard') + assert r.status_code == 404 + + +# ─────────────────── peer_services structure ────────────────────────────────── + +class TestPeerServicesStructure: + """ + peer_services() must return the exact structure MyServices.jsx reads. + All field-name mismatches cause silent blanks in the UI. + """ + + def _get(self, peer_client, dns='172.20.0.3'): + wg = _mock_wg(dns) + reg = _mock_registry() + cfg = _mock_config() + with patch('app.peer_registry', reg), \ + patch('app.wireguard_manager', wg), \ + patch('app.config_manager', cfg), \ + patch('app._resolve_peer_dns', return_value=dns): + return peer_client.get('/api/peer/services') + + def test_returns_200(self, peer_client): + r = self._get(peer_client) + assert r.status_code == 200, r.data + + # -- top-level keys -------------------------------------------------------- + + def test_has_username_at_top_level(self, peer_client): + """MyServices.jsx uses data?.username for the WireGuard config download filename.""" + r = self._get(peer_client) + data = r.get_json() + assert 'username' in data, f"'username' missing from top level; keys: {list(data)}" + assert data['username'] == 'alice' + + def test_has_files_not_webdav(self, peer_client): + """MyServices.jsx reads data?.files — key must be 'files', not 'webdav'.""" + r = self._get(peer_client) + data = r.get_json() + assert 'files' in data, f"'files' key missing; keys: {list(data)}" + assert 'webdav' not in data, \ + "'webdav' key still present — MyServices.jsx reads 'files'" + + def test_has_wireguard_email_caldav_files(self, peer_client): + r = self._get(peer_client) + data = r.get_json() + for section in ('wireguard', 'email', 'caldav', 'files'): + assert section in data, f"'{section}' section missing; keys: {list(data)}" + + # -- email section --------------------------------------------------------- + + def test_email_address_not_username(self, peer_client): + """MyServices.jsx reads email.address — must NOT be email.username.""" + r = self._get(peer_client) + email = r.get_json()['email'] + assert 'address' in email, f"'address' missing from email; keys: {list(email)}" + assert 'username' not in email, \ + "'username' still in email section — MyServices.jsx reads 'address'" + + def test_email_address_is_full_address(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['email']['address'] == 'alice@dev' + + def test_email_has_nested_smtp(self, peer_client): + """MyServices.jsx reads email.smtp.host and email.smtp.port as nested objects.""" + r = self._get(peer_client) + email = r.get_json()['email'] + assert 'smtp' in email, f"'smtp' missing from email; keys: {list(email)}" + smtp = email['smtp'] + assert 'host' in smtp, f"'host' missing from email.smtp; keys: {list(smtp)}" + assert 'port' in smtp, f"'port' missing from email.smtp; keys: {list(smtp)}" + + def test_email_has_nested_imap(self, peer_client): + """MyServices.jsx reads email.imap.host and email.imap.port as nested objects.""" + r = self._get(peer_client) + email = r.get_json()['email'] + assert 'imap' in email, f"'imap' missing from email; keys: {list(email)}" + imap = email['imap'] + assert 'host' in imap, f"'host' missing from email.imap; keys: {list(imap)}" + assert 'port' in imap, f"'port' missing from email.imap; keys: {list(imap)}" + + def test_email_no_flat_host_fields(self, peer_client): + """Flat imap_host/smtp_host fields must not be present.""" + r = self._get(peer_client) + email = r.get_json()['email'] + assert 'imap_host' not in email, \ + "'imap_host' still flat — MyServices.jsx reads email.imap.host" + assert 'smtp_host' not in email, \ + "'smtp_host' still flat — MyServices.jsx reads email.smtp.host" + + def test_email_smtp_host_is_mail_domain(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['email']['smtp']['host'] == 'mail.dev' + + def test_email_imap_host_is_mail_domain(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['email']['imap']['host'] == 'mail.dev' + + def test_email_smtp_port(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['email']['smtp']['port'] == 587 + + def test_email_imap_port(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['email']['imap']['port'] == 993 + + # -- caldav section -------------------------------------------------------- + + def test_caldav_url_uses_calendar_subdomain(self, peer_client): + """CalDAV URL must be http://calendar.{domain}, not radicale.{domain}:5232. + radicale.dev has no DNS record; calendar.dev is the Caddy-proxied entry.""" + r = self._get(peer_client) + url = r.get_json()['caldav']['url'] + assert 'radicale' not in url, \ + f"CalDAV URL contains 'radicale' — no DNS record exists for radicale.dev; got: {url}" + assert ':5232' not in url, \ + f"CalDAV URL exposes internal port 5232 — should use Caddy-proxied URL; got: {url}" + assert url == f'http://calendar.dev', f"CalDAV URL wrong: {url}" + + def test_caldav_username_is_peer_name(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['caldav']['username'] == 'alice' + + # -- files section --------------------------------------------------------- + + def test_files_url_uses_files_subdomain(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['files']['url'] == 'http://files.dev' + + def test_files_username_is_peer_name(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['files']['username'] == 'alice' + + # -- wireguard section ----------------------------------------------------- + + def test_wireguard_dns_is_not_vpn_gateway(self, peer_client): + """DNS must be the CoreDNS container IP, not the WireGuard VPN gateway 10.0.0.1. + 10.0.0.1 is the WireGuard server-side VPN IP which does NOT serve DNS.""" + r = self._get(peer_client) + dns = r.get_json()['wireguard'].get('dns', '') + assert dns != '10.0.0.1', \ + "wireguard.dns is 10.0.0.1 (WireGuard VPN gateway) — should be CoreDNS IP" + + def test_wireguard_dns_is_coredns_ip(self, peer_client): + """DNS must be 172.20.0.3 (the CoreDNS container on the Docker bridge).""" + r = self._get(peer_client) + assert r.get_json()['wireguard']['dns'] == '172.20.0.3' + + def test_wireguard_has_config_field(self, peer_client): + """wg.config field allows peer to download/copy their WireGuard config.""" + r = self._get(peer_client) + wg = r.get_json()['wireguard'] + assert 'config' in wg, f"'config' missing from wireguard section; keys: {list(wg)}" + + def test_wireguard_config_has_dns_line(self, peer_client): + """The config text must contain a DNS = line pointing to CoreDNS.""" + r = self._get(peer_client) + config = r.get_json()['wireguard'].get('config', '') + assert 'DNS = 172.20.0.3' in config, \ + f"Config missing 'DNS = 172.20.0.3'; config:\n{config}" + + def test_wireguard_config_has_interface_section(self, peer_client): + r = self._get(peer_client) + config = r.get_json()['wireguard'].get('config', '') + assert '[Interface]' in config and '[Peer]' in config, \ + f"Config missing [Interface] or [Peer] section; config:\n{config}" + + def test_wireguard_config_has_full_tunnel_allowed_ips(self, peer_client): + """Full-tunnel peers must have AllowedIPs = 0.0.0.0/0 so all traffic goes via VPN.""" + r = self._get(peer_client) + config = r.get_json()['wireguard'].get('config', '') + assert '0.0.0.0/0' in config, \ + f"Config missing 0.0.0.0/0 AllowedIPs for full-tunnel peer; config:\n{config}" + + def test_wireguard_has_ip_field(self, peer_client): + r = self._get(peer_client) + assert 'ip' in r.get_json()['wireguard'] + + def test_wireguard_ip_is_peer_vpn_ip(self, peer_client): + r = self._get(peer_client) + assert r.get_json()['wireguard']['ip'] == '10.0.0.5' + + +# ─────────────────── auth / access control ──────────────────────────────────── + +class TestPeerEndpointAccessControl: + """Peer-only routes must block unauthenticated and admin sessions.""" + + def test_unauthenticated_dashboard_returns_401(self, auth_mgr): + app.config['TESTING'] = True + app.config['SECRET_KEY'] = 'test-secret' + with patch('app.auth_manager', auth_mgr): + with app.test_client() as c: + r = c.get('/api/peer/dashboard') + assert r.status_code == 401 + + def test_unauthenticated_services_returns_401(self, auth_mgr): + app.config['TESTING'] = True + app.config['SECRET_KEY'] = 'test-secret' + with patch('app.auth_manager', auth_mgr): + with app.test_client() as c: + r = c.get('/api/peer/services') + assert r.status_code == 401 + + def test_admin_dashboard_returns_403(self, admin_client): + r = admin_client.get('/api/peer/dashboard') + assert r.status_code == 403, \ + f"Admin accessing peer-only /api/peer/dashboard should get 403, got {r.status_code}" + + def test_admin_services_returns_403(self, admin_client): + r = admin_client.get('/api/peer/services') + assert r.status_code == 403, \ + f"Admin accessing peer-only /api/peer/services should get 403, got {r.status_code}" + + +# ─────────────────── DNS zone records ───────────────────────────────────────── + +class TestDNSZoneRecords: + """ + Verify that network_manager._build_dns_records() generates the correct IPs. + api and webui must point to Caddy (not their container IPs) so Caddy can + reverse-proxy them — their containers don't listen on port 80. + """ + + def setUp(self): + pass + + def _records(self, ip_range='172.20.0.0/16', cell_name='pic0'): + import network_manager as nm + mgr = nm.NetworkManager.__new__(nm.NetworkManager) + return mgr._build_dns_records(cell_name, ip_range) + + def test_api_resolves_to_caddy_not_api_container(self): + records = self._records() + api_rec = next((r for r in records if r['name'] == 'api'), None) + assert api_rec is not None, "No DNS record for 'api'" + assert api_rec['value'] == '172.20.0.2', ( + f"api.dev should resolve to Caddy (172.20.0.2), not the API container " + f"(172.20.0.10); got {api_rec['value']}" + ) + + def test_webui_resolves_to_caddy_not_webui_container(self): + records = self._records() + rec = next((r for r in records if r['name'] == 'webui'), None) + assert rec is not None, "No DNS record for 'webui'" + assert rec['value'] == '172.20.0.2', ( + f"webui.dev should resolve to Caddy (172.20.0.2), not the WebUI container " + f"(172.20.0.11); got {rec['value']}" + ) + + def test_calendar_uses_vip(self): + records = self._records() + rec = next((r for r in records if r['name'] == 'calendar'), None) + assert rec and rec['value'] == '172.20.0.21', \ + f"calendar.dev VIP should be 172.20.0.21; got {rec}" + + def test_files_uses_vip(self): + records = self._records() + rec = next((r for r in records if r['name'] == 'files'), None) + assert rec and rec['value'] == '172.20.0.22', \ + f"files.dev VIP should be 172.20.0.22; got {rec}" + + def test_mail_uses_vip(self): + records = self._records() + rec = next((r for r in records if r['name'] == 'mail'), None) + assert rec and rec['value'] == '172.20.0.23', \ + f"mail.dev VIP should be 172.20.0.23; got {rec}" + + def test_webmail_uses_mail_vip(self): + records = self._records() + rec = next((r for r in records if r['name'] == 'webmail'), None) + assert rec and rec['value'] == '172.20.0.23', \ + f"webmail.dev should share the mail VIP 172.20.0.23; got {rec}" + + def test_webdav_uses_vip(self): + records = self._records() + rec = next((r for r in records if r['name'] == 'webdav'), None) + assert rec and rec['value'] == '172.20.0.24', \ + f"webdav.dev VIP should be 172.20.0.24; got {rec}" + + def test_cell_name_resolves_to_caddy(self): + records = self._records(cell_name='mypic') + rec = next((r for r in records if r['name'] == 'mypic'), None) + assert rec and rec['value'] == '172.20.0.2', \ + f"mypic.dev should resolve to Caddy (172.20.0.2); got {rec}" + + def test_all_records_are_type_a(self): + records = self._records() + for rec in records: + assert rec.get('type') == 'A', f"Record {rec} is not type A" + + +class TestDNSZoneRecordsWithPytest: + """Same as above but using pytest-style (no setUp/tearDown).""" + + @pytest.fixture + def records(self): + import network_manager as nm + mgr = nm.NetworkManager.__new__(nm.NetworkManager) + return mgr._build_dns_records('pic0', '172.20.0.0/16') + + def test_api_resolves_to_caddy(self, records): + rec = next((r for r in records if r['name'] == 'api'), None) + assert rec and rec['value'] == '172.20.0.2', \ + f"api.dev should point to Caddy (172.20.0.2); got {rec}" + + def test_webui_resolves_to_caddy(self, records): + rec = next((r for r in records if r['name'] == 'webui'), None) + assert rec and rec['value'] == '172.20.0.2', \ + f"webui.dev should point to Caddy (172.20.0.2); got {rec}" + + +# ─────────────────── Caddyfile generation ───────────────────────────────────── + +class TestCaddyfileGeneration: + """ + write_caddyfile() must produce a Caddyfile that Caddy can use to route + all service domains including webui.dev. + """ + + @pytest.fixture + def caddyfile(self, tmp_path): + from ip_utils import write_caddyfile + path = str(tmp_path / 'Caddyfile') + write_caddyfile('172.20.0.0/16', 'pic0', 'dev', path) + with open(path) as f: + return f.read() + + def test_main_domain_block_present(self, caddyfile): + assert 'http://pic0.dev' in caddyfile + + def test_api_block_present(self, caddyfile): + assert 'http://api.dev' in caddyfile + + def test_webui_block_present(self, caddyfile): + assert 'http://webui.dev' in caddyfile, \ + "Missing webui.dev Caddy block — webui is unreachable by domain name" + + def test_calendar_block_present(self, caddyfile): + assert 'http://calendar.dev' in caddyfile + + def test_files_block_present(self, caddyfile): + assert 'http://files.dev' in caddyfile + + def test_mail_block_present(self, caddyfile): + assert 'http://mail.dev' in caddyfile + + def test_webdav_block_present(self, caddyfile): + assert 'http://webdav.dev' in caddyfile + + def test_caddy_vips_present(self, caddyfile): + assert '172.20.0.21' in caddyfile, "calendar VIP missing from Caddyfile" + assert '172.20.0.22' in caddyfile, "files VIP missing from Caddyfile" + assert '172.20.0.23' in caddyfile, "mail VIP missing from Caddyfile" + assert '172.20.0.24' in caddyfile, "webdav VIP missing from Caddyfile" + + def test_no_radicale_subdomain_in_caddyfile(self, caddyfile): + """radicale.dev has no DNS record; CalDAV should go via calendar.dev.""" + assert 'radicale.dev' not in caddyfile, \ + "radicale.dev should not appear in Caddyfile — no DNS record for it" + + def test_auto_https_off(self, caddyfile): + assert 'auto_https off' in caddyfile + + def test_reverse_proxy_targets_use_container_names(self, caddyfile): + """Container-internal routing must use service names not IPs.""" + assert 'cell-api:3000' in caddyfile + assert 'cell-radicale:5232' in caddyfile + assert 'cell-webui:80' in caddyfile diff --git a/tests/test_peer_management_edge_cases.py b/tests/test_peer_management_edge_cases.py new file mode 100644 index 0000000..2a9d203 --- /dev/null +++ b/tests/test_peer_management_edge_cases.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Edge-case tests for peer management endpoints in api/app.py. + +Key scenarios: +- POST /api/peers with subnet exhaustion (_next_peer_ip raises ValueError) → 409 +- POST /api/peers//clear-reinstall: success (200) +- POST /api/peers//clear-reinstall: unknown peer raises → 500 +- POST /api/ip-update: missing 'peer' field → 400 +- POST /api/ip-update: missing 'ip' field → 400 +- POST /api/ip-update: unknown peer → 404 +- POST /api/ip-update: success → 200 +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestAddPeerSubnetExhaustion(unittest.TestCase): + """POST /api/peers with no free IPs left must return 409, not 500.""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app._next_peer_ip') + @patch('app.auth_manager') + def test_add_peer_returns_409_when_subnet_exhausted(self, mock_auth, mock_next_ip): + mock_auth.create_user.return_value = True + mock_next_ip.side_effect = ValueError('No free IPs left in 10.0.0.0/24') + + r = self.client.post( + '/api/peers', + data=json.dumps({ + 'name': 'newpeer', + 'public_key': 'PUBKEY==', + 'password': 'verysecret123', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 409) + data = json.loads(r.data) + self.assertIn('error', data) + + @patch('app._next_peer_ip') + @patch('app.auth_manager') + def test_add_peer_409_error_message_mentions_ip(self, mock_auth, mock_next_ip): + mock_auth.create_user.return_value = True + mock_next_ip.side_effect = ValueError('No free IPs left in 10.0.0.0/24') + + r = self.client.post( + '/api/peers', + data=json.dumps({ + 'name': 'newpeer', + 'public_key': 'PUBKEY==', + 'password': 'verysecret123', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 409) + data = json.loads(r.data) + self.assertIn('No free IPs', data['error']) + + +class TestClearReinstallFlag(unittest.TestCase): + """POST /api/peers//clear-reinstall""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.peer_registry') + def test_clear_reinstall_returns_200_on_success(self, mock_reg): + mock_reg.clear_reinstall_flag.return_value = True + r = self.client.post('/api/peers/alice/clear-reinstall') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.peer_registry') + def test_clear_reinstall_calls_registry_with_peer_name(self, mock_reg): + mock_reg.clear_reinstall_flag.return_value = True + self.client.post('/api/peers/bob/clear-reinstall') + mock_reg.clear_reinstall_flag.assert_called_once_with('bob') + + @patch('app.peer_registry') + def test_clear_reinstall_returns_500_when_exception_raised(self, mock_reg): + mock_reg.clear_reinstall_flag.side_effect = Exception('peer not found') + r = self.client.post('/api/peers/ghost/clear-reinstall') + self.assertEqual(r.status_code, 500) + data = json.loads(r.data) + self.assertIn('error', data) + + +class TestIpUpdate(unittest.TestCase): + """POST /api/ip-update""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + @patch('app.peer_registry') + def test_ip_update_returns_200_on_success(self, mock_reg, mock_rm): + mock_reg.update_peer_ip.return_value = True + mock_rm.update_peer_ip.return_value = None + + r = self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'alice', 'ip': '10.0.0.99'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.peer_registry') + def test_ip_update_returns_400_when_peer_field_missing(self, mock_reg): + r = self.client.post( + '/api/ip-update', + data=json.dumps({'ip': '10.0.0.99'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + mock_reg.update_peer_ip.assert_not_called() + + @patch('app.peer_registry') + def test_ip_update_returns_400_when_ip_field_missing(self, mock_reg): + r = self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + mock_reg.update_peer_ip.assert_not_called() + + @patch('app.peer_registry') + def test_ip_update_returns_400_when_no_body(self, mock_reg): + r = self.client.post('/api/ip-update') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.peer_registry') + def test_ip_update_returns_404_when_peer_not_found(self, mock_reg): + mock_reg.update_peer_ip.return_value = False + r = self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'ghost', 'ip': '10.0.0.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 404) + data = json.loads(r.data) + self.assertIn('error', data) + + @patch('app.routing_manager') + @patch('app.peer_registry') + def test_ip_update_calls_registry_with_correct_args(self, mock_reg, mock_rm): + mock_reg.update_peer_ip.return_value = True + mock_rm.update_peer_ip.return_value = None + + self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'alice', 'ip': '10.0.0.5'}), + content_type='application/json', + ) + mock_reg.update_peer_ip.assert_called_once_with('alice', '10.0.0.5') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_peer_management_update.py b/tests/test_peer_management_update.py new file mode 100644 index 0000000..a1ac567 --- /dev/null +++ b/tests/test_peer_management_update.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Tests for PUT /api/peers/. + +Key scenarios: +- 404 when peer_registry.get_peer returns None +- 200 on successful update +- config_needs_reinstall=True in response when internet_access changes +- config_needs_reinstall=False (config_changed=False) when only description changes +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestUpdatePeer(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_returns_404_when_peer_not_found(self, mock_reg, mock_fw): + mock_reg.get_peer.return_value = None + r = self.client.put( + '/api/peers/ghost', + data=json.dumps({'description': 'updated'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 404) + data = json.loads(r.data) + self.assertIn('error', data) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_returns_200_on_success(self, mock_reg, mock_fw): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'my laptop'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_config_changed_true_when_internet_access_changes( + self, mock_reg, mock_fw + ): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'internet_access': False}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertTrue(data['config_changed']) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_config_changed_false_when_only_description_changes( + self, mock_reg, mock_fw + ): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'just a label'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertFalse(data['config_changed']) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_returns_500_when_update_fails(self, mock_reg, mock_fw): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = False + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'fail'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_config_changed_true_when_ip_changes(self, mock_reg, mock_fw): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'ip': '10.0.0.99'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertTrue(data['config_changed']) + + @patch('app.peer_registry') + def test_update_peer_returns_500_on_exception(self, mock_reg): + mock_reg.get_peer.side_effect = Exception('disk error') + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'test'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_routing_endpoints.py b/tests/test_routing_endpoints.py index cd725ab..1c55d32 100644 --- a/tests/test_routing_endpoints.py +++ b/tests/test_routing_endpoints.py @@ -1 +1,294 @@ -# ... moved and adapted code from test_phase4_endpoints.py ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for routing Flask endpoints in api/app.py. + +Covers: + POST /api/routing/peers (peer_name + peer_ip required) + POST /api/routing/exit-nodes (peer_name + peer_ip required) + POST /api/routing/bridge (source_peer + target_peer required) + POST /api/routing/split (network + exit_peer required) + GET /api/routing/peers + DELETE /api/routing/peers/ +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestAddPeerRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_peer_route_returns_200_on_success(self, mock_rm): + mock_rm.add_peer_route.return_value = True + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_name': 'alice', 'peer_ip': '10.0.0.2'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_peer_route_returns_400_when_peer_name_missing(self, mock_rm): + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_ip': '10.0.0.2'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_peer_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_peer_route_returns_400_when_peer_ip_missing(self, mock_rm): + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_name': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_peer_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_peer_route_returns_500_on_exception(self, mock_rm): + mock_rm.add_peer_route.side_effect = Exception('iptables error') + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_name': 'alice', 'peer_ip': '10.0.0.2'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddExitNode(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_exit_node_returns_200_on_success(self, mock_rm): + mock_rm.add_exit_node.return_value = True + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_name': 'gw', 'peer_ip': '10.0.0.5'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_exit_node_returns_400_when_peer_name_missing(self, mock_rm): + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_ip': '10.0.0.5'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_exit_node.assert_not_called() + + @patch('app.routing_manager') + def test_add_exit_node_returns_400_when_peer_ip_missing(self, mock_rm): + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_name': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_exit_node.assert_not_called() + + @patch('app.routing_manager') + def test_add_exit_node_returns_500_on_exception(self, mock_rm): + mock_rm.add_exit_node.side_effect = Exception('routing table full') + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_name': 'gw', 'peer_ip': '10.0.0.5'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddBridgeRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_bridge_returns_200_on_success(self, mock_rm): + mock_rm.add_bridge_route.return_value = True + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'source_peer': 'alice', 'target_peer': 'bob'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_bridge_returns_400_when_source_peer_missing(self, mock_rm): + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'target_peer': 'bob'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_bridge_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_bridge_returns_400_when_target_peer_missing(self, mock_rm): + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'source_peer': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_bridge_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_bridge_returns_500_on_exception(self, mock_rm): + mock_rm.add_bridge_route.side_effect = Exception('bridge setup failed') + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'source_peer': 'alice', 'target_peer': 'bob'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddSplitRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_split_returns_200_on_success(self, mock_rm): + mock_rm.add_split_route.return_value = True + r = self.client.post( + '/api/routing/split', + data=json.dumps({'network': '192.168.10.0/24', 'exit_peer': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_split_returns_400_when_network_missing(self, mock_rm): + r = self.client.post( + '/api/routing/split', + data=json.dumps({'exit_peer': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_split_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_split_returns_400_when_exit_peer_missing(self, mock_rm): + r = self.client.post( + '/api/routing/split', + data=json.dumps({'network': '192.168.10.0/24'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_split_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_split_returns_500_on_exception(self, mock_rm): + mock_rm.add_split_route.side_effect = Exception('split tunnel error') + r = self.client.post( + '/api/routing/split', + data=json.dumps({'network': '192.168.10.0/24', 'exit_peer': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetPeerRoutes(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_get_peer_routes_returns_200_with_routes(self, mock_rm): + mock_rm.get_peer_routes.return_value = [ + {'peer_name': 'alice', 'peer_ip': '10.0.0.2', 'route_type': 'lan'}, + ] + r = self.client.get('/api/routing/peers') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('peer_routes', data) + self.assertIsInstance(data['peer_routes'], list) + + @patch('app.routing_manager') + def test_get_peer_routes_returns_empty_list_when_no_routes(self, mock_rm): + mock_rm.get_peer_routes.return_value = [] + r = self.client.get('/api/routing/peers') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertEqual(data['peer_routes'], []) + + @patch('app.routing_manager') + def test_get_peer_routes_returns_500_on_exception(self, mock_rm): + mock_rm.get_peer_routes.side_effect = Exception('DB error') + r = self.client.get('/api/routing/peers') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeletePeerRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_delete_peer_route_returns_200_on_success(self, mock_rm): + mock_rm.remove_peer_route.return_value = {'removed': True} + r = self.client.delete('/api/routing/peers/alice') + self.assertEqual(r.status_code, 200) + + @patch('app.routing_manager') + def test_delete_peer_route_calls_manager_with_name(self, mock_rm): + mock_rm.remove_peer_route.return_value = {'removed': True} + self.client.delete('/api/routing/peers/bob') + mock_rm.remove_peer_route.assert_called_once_with('bob') + + @patch('app.routing_manager') + def test_delete_peer_route_returns_500_on_exception(self, mock_rm): + mock_rm.remove_peer_route.side_effect = Exception('iptables flush error') + r = self.client.delete('/api/routing/peers/alice') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_wireguard_vpn_routing.py b/tests/test_wireguard_vpn_routing.py new file mode 100644 index 0000000..777deb2 --- /dev/null +++ b/tests/test_wireguard_vpn_routing.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python3 +""" +Tests for WireGuard VPN routing: internet access and DNS resolution through tunnel. + +Scenarios covered: + 1. generate_config() produces PostUp/PostDown rules that enable internet forwarding + (MASQUERADE + FORWARD ACCEPT are the two iptables rules that make "internet + through VPN" work — without them, packets from 10.0.0.x are not NATted to eth0). + 2. get_peer_config() sets DNS = so clients resolve domain names + through the PIC DNS container, not their local ISP resolver. + 3. apply_config() bootstrap path (empty wg0.conf) restores all active peers from + peers.json so clients can reconnect after an API restart that regenerated the file. + 4. _load_registered_peers() correctly filters peers.json. + 5. add_peer() writes a /32 AllowedIPs entry so routing targets only that client. +""" + +import sys +import os +import json +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from wireguard_manager import WireGuardManager, _resolve_peer_dns + + +# A syntactically-valid WireGuard base64 public key (44 chars, ends with =). +FAKE_PUBKEY = 'O35JY6nc8sb9zEarZYZVl70jno/J9dRyiB37YSYy4nA=' +FAKE_PUBKEY2 = 'AbCdEfGhIjKlMnOpQrStUvWxYz0123456789ABCDEFG=' + + +def _make_wg(tmp: str) -> WireGuardManager: + """Build a WireGuardManager rooted in *tmp*, with _syncconf disabled.""" + with patch.object(WireGuardManager, '_syncconf', return_value=None): + wg = WireGuardManager(tmp, tmp) + return wg + + +# ── 1. Internet forwarding rules in generate_config() ───────────────────────── + +class TestInternetForwardingRules(unittest.TestCase): + """ + Verify that generate_config() emits the exact iptables rules required for + 'internet through VPN': MASQUERADE on eth0 (outbound NAT) and FORWARD ACCEPT + on the wg0 interface. Missing either rule means VPN clients get no internet. + """ + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + with patch.object(WireGuardManager, '_syncconf', return_value=None): + self.wg = WireGuardManager(self.tmp, self.tmp) + + def test_postup_has_masquerade_on_eth0(self): + """MASQUERADE on eth0 NATs VPN-subnet packets so internet routers see the host IP.""" + cfg = self.wg.generate_config() + self.assertIn('POSTROUTING -o eth0 -j MASQUERADE', cfg) + + def test_postup_has_forward_accept_on_wg_interface(self): + """FORWARD ACCEPT allows packets from the WireGuard interface through the kernel.""" + cfg = self.wg.generate_config() + self.assertIn('FORWARD -i %i -j ACCEPT', cfg) + + def test_postdown_removes_masquerade_rule(self): + """PostDown must mirror PostUp so rules are cleaned up when the tunnel goes down.""" + cfg = self.wg.generate_config() + self.assertIn('POSTROUTING -o eth0 -j MASQUERADE', cfg.split('PostDown')[1]) + + def test_postdown_removes_forward_rule(self): + cfg = self.wg.generate_config() + self.assertIn('FORWARD -i %i -j ACCEPT', cfg.split('PostDown')[1]) + + def test_postup_and_postdown_are_present(self): + """Both PostUp and PostDown must exist — PostUp without PostDown leaks rules.""" + cfg = self.wg.generate_config() + self.assertIn('PostUp', cfg) + self.assertIn('PostDown', cfg) + + def test_masquerade_is_in_postup_not_only_postdown(self): + """MASQUERADE must appear in PostUp (adding the rule), not only PostDown.""" + cfg = self.wg.generate_config() + postup_section = cfg.split('PostUp')[1].split('PostDown')[0] + self.assertIn('MASQUERADE', postup_section) + + +# ── 2. DNS resolution: get_peer_config() sets DNS field ─────────────────────── + +class TestPeerConfigDns(unittest.TestCase): + """ + Verify that peer client configs include a DNS = line pointing to the + PIC DNS container. Without DNS, the client tunnel has no internet-accessible + domain resolution even though packets are forwarded correctly. + """ + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + with patch.object(WireGuardManager, '_syncconf', return_value=None): + self.wg = WireGuardManager(self.tmp, self.tmp) + + def test_peer_config_contains_dns_line(self): + keys = self.wg.generate_peer_keys('testpeer') + cfg = self.wg.get_peer_config('testpeer', '10.0.0.2', keys['private_key']) + self.assertIn('DNS =', cfg) + + def test_peer_config_dns_is_valid_ip(self): + import ipaddress + keys = self.wg.generate_peer_keys('testpeer') + cfg = self.wg.get_peer_config('testpeer', '10.0.0.2', keys['private_key']) + dns_line = next(l for l in cfg.splitlines() if l.startswith('DNS =')) + dns_ip = dns_line.split('=', 1)[1].strip() + # Must be a parseable IPv4 address + ipaddress.IPv4Address(dns_ip) + + def test_peer_config_dns_defaults_to_cell_dns_ip(self): + """When cell-dns hostname can't be resolved, falls back to 172.20.0.3.""" + with patch('wireguard_manager.socket.gethostbyname', side_effect=OSError): + keys = self.wg.generate_peer_keys('p1') + cfg = self.wg.get_peer_config('p1', '10.0.0.5', keys['private_key']) + self.assertIn('DNS = 172.20.0.3', cfg) + + def test_peer_config_dns_uses_resolved_hostname(self): + """When cell-dns resolves, its IP is used as the DNS server.""" + with patch('wireguard_manager.socket.gethostbyname', return_value='172.20.0.3'): + keys = self.wg.generate_peer_keys('p2') + cfg = self.wg.get_peer_config('p2', '10.0.0.6', keys['private_key']) + self.assertIn('DNS = 172.20.0.3', cfg) + + def test_resolve_peer_dns_fallback(self): + """_resolve_peer_dns() always returns a string even when DNS lookup fails.""" + with patch('wireguard_manager.socket.gethostbyname', side_effect=OSError): + result = _resolve_peer_dns() + self.assertIsInstance(result, str) + self.assertEqual(result, '172.20.0.3') + + def test_peer_config_allowed_ips_default_full_tunnel(self): + """Default AllowedIPs = 0.0.0.0/0 routes all traffic (including internet) through VPN.""" + keys = self.wg.generate_peer_keys('p3') + cfg = self.wg.get_peer_config('p3', '10.0.0.7', keys['private_key']) + # Full tunnel: 0.0.0.0/0 means all traffic goes through the VPN + self.assertIn('0.0.0.0/0', cfg) + + +# ── 3. Bootstrap restores peers from peers.json ─────────────────────────────── + +class TestApplyConfigBootstrapRestoresPeers(unittest.TestCase): + """ + apply_config() is called when the WireGuard port changes. If wg0.conf is + empty or missing [Interface], it bootstraps from generate_config() — which + only generates the [Interface] section and loses all [Peer] blocks. + + The fix: after bootstrap, load active peers from peers.json and restore their + [Peer] blocks so clients can reconnect without manual intervention. + """ + + def _make_wg_with_conf(self, conf_content: str = '') -> tuple: + tmp = tempfile.mkdtemp() + with patch.object(WireGuardManager, '_syncconf', return_value=None): + wg = WireGuardManager(tmp, tmp) + + # Ensure wg_confs/ dir and write the file + cf = wg._config_file() + os.makedirs(os.path.dirname(cf), exist_ok=True) + with open(cf, 'w') as f: + f.write(conf_content) + return wg, cf, tmp + + def _write_peers_json(self, wg: WireGuardManager, peers: list): + peers_file = os.path.join(wg.data_dir, 'peers.json') + with open(peers_file, 'w') as f: + json.dump(peers, f) + + def tearDown(self): + pass # each test manages its own tmp + + def test_empty_conf_triggers_bootstrap(self): + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, []) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + result = wg.apply_config({'port': 51820}) + self.assertIn('wg0.conf was empty — regenerated from keys', result['warnings']) + finally: + shutil.rmtree(tmp) + + def test_bootstrap_restores_active_peer(self): + """After bootstrap on empty conf, active peer from peers.json appears in wg0.conf.""" + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, [{ + 'peer': 'user1', + 'ip': '10.0.0.2', + 'public_key': FAKE_PUBKEY, + 'active': True, + }]) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + wg.apply_config({'port': 51820}) + content = open(cf).read() + self.assertIn('[Peer]', content) + self.assertIn(FAKE_PUBKEY, content) + self.assertIn('AllowedIPs = 10.0.0.2/32', content) + finally: + shutil.rmtree(tmp) + + def test_bootstrap_restores_multiple_peers(self): + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, [ + {'peer': 'peer1', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY, 'active': True}, + {'peer': 'peer2', 'ip': '10.0.0.3', 'public_key': FAKE_PUBKEY2, 'active': True}, + ]) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + wg.apply_config({'port': 51820}) + content = open(cf).read() + self.assertIn(FAKE_PUBKEY, content) + self.assertIn(FAKE_PUBKEY2, content) + self.assertEqual(content.count('[Peer]'), 2) + finally: + shutil.rmtree(tmp) + + def test_bootstrap_skips_inactive_peers(self): + """Inactive peers (active=False) must NOT be restored to wg0.conf.""" + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, [ + {'peer': 'active', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY, 'active': True}, + {'peer': 'inactive', 'ip': '10.0.0.3', 'public_key': FAKE_PUBKEY2, 'active': False}, + ]) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + wg.apply_config({'port': 51820}) + content = open(cf).read() + self.assertIn(FAKE_PUBKEY, content) + self.assertNotIn(FAKE_PUBKEY2, content) + finally: + shutil.rmtree(tmp) + + def test_bootstrap_skips_peer_missing_public_key(self): + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, [ + {'peer': 'nok', 'ip': '10.0.0.2', 'active': True}, # no public_key + ]) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + wg.apply_config({'port': 51820}) + content = open(cf).read() + self.assertEqual(content.count('[Peer]'), 0) + finally: + shutil.rmtree(tmp) + + def test_bootstrap_skips_peer_missing_ip(self): + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, [ + {'peer': 'nok', 'public_key': FAKE_PUBKEY, 'active': True}, # no ip + ]) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + wg.apply_config({'port': 51820}) + content = open(cf).read() + self.assertNotIn(FAKE_PUBKEY, content) + finally: + shutil.rmtree(tmp) + + def test_existing_conf_with_interface_not_bootstrapped(self): + """If [Interface] is present, bootstrap must NOT run — existing peers are preserved.""" + wg, cf, tmp = self._make_wg_with_conf( + '[Interface]\nListenPort = 51820\nPrivateKey = dummykey\n' + '\n[Peer]\n# existing\nPublicKey = ' + FAKE_PUBKEY + '\nAllowedIPs = 10.0.0.2/32\n' + ) + try: + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + result = wg.apply_config({'port': 51821}) + self.assertNotIn('wg0.conf was empty — regenerated from keys', result['warnings']) + # Original peer must still be there after port-only change + self.assertIn(FAKE_PUBKEY, open(cf).read()) + finally: + shutil.rmtree(tmp) + + def test_restored_peers_have_slash32_allowed_ips(self): + """/32 is mandatory: a wider mask would route internet traffic to the wrong peer.""" + wg, cf, tmp = self._make_wg_with_conf('') + try: + self._write_peers_json(wg, [ + {'peer': 'user1', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY, 'active': True}, + ]) + with patch.object(wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + wg.apply_config({'port': 51820}) + content = open(cf).read() + # Must be /32, not /24 or /0 + self.assertIn('AllowedIPs = 10.0.0.2/32', content) + self.assertNotIn('AllowedIPs = 10.0.0.2/24', content) + finally: + shutil.rmtree(tmp) + + +# ── 4. _load_registered_peers() ─────────────────────────────────────────────── + +class TestLoadRegisteredPeers(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + with patch.object(WireGuardManager, '_syncconf', return_value=None): + self.wg = WireGuardManager(self.tmp, self.tmp) + + def _write_peers(self, peers: list): + path = os.path.join(self.wg.data_dir, 'peers.json') + with open(path, 'w') as f: + json.dump(peers, f) + + def test_returns_empty_list_when_file_missing(self): + self.assertEqual(self.wg._load_registered_peers(), []) + + def test_returns_empty_list_on_malformed_json(self): + path = os.path.join(self.wg.data_dir, 'peers.json') + with open(path, 'w') as f: + f.write('not json {{{') + self.assertEqual(self.wg._load_registered_peers(), []) + + def test_returns_active_peers(self): + self._write_peers([ + {'peer': 'u1', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY, 'active': True}, + ]) + result = self.wg._load_registered_peers() + self.assertEqual(len(result), 1) + self.assertEqual(result[0]['public_key'], FAKE_PUBKEY) + + def test_filters_out_inactive_peers(self): + self._write_peers([ + {'peer': 'u1', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY, 'active': True}, + {'peer': 'u2', 'ip': '10.0.0.3', 'public_key': FAKE_PUBKEY2, 'active': False}, + ]) + result = self.wg._load_registered_peers() + self.assertEqual(len(result), 1) + self.assertEqual(result[0]['public_key'], FAKE_PUBKEY) + + def test_filters_out_peers_without_public_key(self): + self._write_peers([ + {'peer': 'u1', 'ip': '10.0.0.2', 'active': True}, + ]) + self.assertEqual(self.wg._load_registered_peers(), []) + + def test_filters_out_peers_without_ip(self): + self._write_peers([ + {'peer': 'u1', 'public_key': FAKE_PUBKEY, 'active': True}, + ]) + self.assertEqual(self.wg._load_registered_peers(), []) + + def test_treats_missing_active_field_as_active(self): + """Peers without 'active' key should be treated as active (default True).""" + self._write_peers([ + {'peer': 'u1', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY}, + ]) + result = self.wg._load_registered_peers() + self.assertEqual(len(result), 1) + + def test_skips_non_dict_entries(self): + self._write_peers([ + 'not_a_dict', + {'peer': 'u1', 'ip': '10.0.0.2', 'public_key': FAKE_PUBKEY, 'active': True}, + ]) + result = self.wg._load_registered_peers() + self.assertEqual(len(result), 1) + + def test_returns_all_required_fields(self): + self._write_peers([ + {'peer': 'u1', 'ip': '10.0.0.5', 'public_key': FAKE_PUBKEY, 'active': True}, + ]) + result = self.wg._load_registered_peers() + self.assertIn('ip', result[0]) + self.assertIn('public_key', result[0]) + + +# ── 5. add_peer() writes correct server-side AllowedIPs ─────────────────────── + +class TestAddPeerServerSideAllowedIps(unittest.TestCase): + """ + Server-side AllowedIPs must be a /32 host address matching the peer's VPN IP. + Wider masks (e.g. 0.0.0.0/0) would route internet traffic from all other + clients to that single peer, breaking internet access for everyone else. + """ + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + with patch.object(WireGuardManager, '_syncconf', return_value=None): + self.wg = WireGuardManager(self.tmp, self.tmp) + + def test_add_peer_writes_slash32_allowed_ips(self): + ok = self.wg.add_peer('peer1', FAKE_PUBKEY, '', '10.0.0.2/32') + self.assertTrue(ok) + content = open(self.wg._config_file()).read() + self.assertIn('AllowedIPs = 10.0.0.2/32', content) + + def test_add_peer_rejects_full_tunnel_allowed_ips(self): + """0.0.0.0/0 as server AllowedIPs is invalid and must be rejected.""" + ok = self.wg.add_peer('peer1', FAKE_PUBKEY, '', '0.0.0.0/0') + self.assertFalse(ok) + + def test_add_peer_rejects_subnet_allowed_ips(self): + """10.0.0.0/24 as server AllowedIPs is invalid and must be rejected.""" + ok = self.wg.add_peer('peer1', FAKE_PUBKEY, '', '10.0.0.0/24') + self.assertFalse(ok) + + def test_add_peer_does_not_write_peer_on_rejection(self): + # Add a valid peer first so the conf file exists, then attempt bad add + self.wg.add_peer('valid', FAKE_PUBKEY2, '', '10.0.0.99/32') + ok = self.wg.add_peer('peer1', FAKE_PUBKEY, '', '0.0.0.0/0') + self.assertFalse(ok) + content = open(self.wg._config_file()).read() + # The bad peer's key must not appear; the valid one may + self.assertNotIn(FAKE_PUBKEY, content) + + def test_add_peer_writes_public_key(self): + self.wg.add_peer('peer1', FAKE_PUBKEY, '', '10.0.0.2/32') + content = open(self.wg._config_file()).read() + self.assertIn(f'PublicKey = {FAKE_PUBKEY}', content) + + def test_add_peer_writes_peer_name_as_comment(self): + self.wg.add_peer('user1', FAKE_PUBKEY, '', '10.0.0.2/32') + content = open(self.wg._config_file()).read() + self.assertIn('# user1', content) + + def test_add_peer_writes_persistent_keepalive(self): + self.wg.add_peer('peer1', FAKE_PUBKEY, '', '10.0.0.2/32', 25) + content = open(self.wg._config_file()).read() + self.assertIn('PersistentKeepalive = 25', content) + + +# ── 6. Key sync: _sync_keys_from_conf() ────────────────────────────────────── + +class TestSyncKeysFromConf(unittest.TestCase): + """ + linuxserver/wireguard auto-generates its own PrivateKey on first container start. + The PIC API generates a separate key independently. _sync_keys_from_conf() must + detect the mismatch and update the API key-store so get_peer_config() embeds + the correct server public key — otherwise the WireGuard handshake fails silently. + """ + + def _make_wg(self, tmp: str) -> WireGuardManager: + with patch.object(WireGuardManager, '_syncconf', return_value=None): + return WireGuardManager(tmp, tmp) + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + with patch.object(WireGuardManager, '_syncconf', return_value=None): + self.wg = WireGuardManager(self.tmp, self.tmp) + + def _write_conf_with_key(self, priv_b64: str): + """Write a minimal wg0.conf with the given PrivateKey.""" + cf = self.wg._config_file() + os.makedirs(os.path.dirname(cf), exist_ok=True) + with open(cf, 'w') as f: + f.write(f'[Interface]\nPrivateKey = {priv_b64}\nListenPort = 51820\nAddress = 10.0.0.1/24\n') + + def _generate_key_pair(self): + from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey + import base64 as _b64 + priv = X25519PrivateKey.generate() + priv_bytes = priv.private_bytes_raw() + pub_bytes = priv.public_key().public_bytes_raw() + return _b64.b64encode(priv_bytes).decode(), _b64.b64encode(pub_bytes).decode() + + def test_sync_updates_api_key_when_conf_differs(self): + """When wg0.conf has a different PrivateKey, the API key-store must be updated.""" + new_priv, new_pub = self._generate_key_pair() + self._write_conf_with_key(new_priv) + self.wg._sync_keys_from_conf() + api_keys = self.wg.get_keys() + self.assertEqual(api_keys['private_key'], new_priv) + self.assertEqual(api_keys['public_key'], new_pub) + + def test_sync_no_op_when_keys_match(self): + """If wg0.conf already has the same key as the API store, nothing changes.""" + api_keys = self.wg.get_keys() + self._write_conf_with_key(api_keys['private_key']) + self.wg._sync_keys_from_conf() # should not raise or change anything + after = self.wg.get_keys() + self.assertEqual(api_keys['public_key'], after['public_key']) + + def test_sync_makes_get_peer_config_use_correct_server_pubkey(self): + """After sync, get_peer_config() must embed the updated server public key.""" + new_priv, new_pub = self._generate_key_pair() + self._write_conf_with_key(new_priv) + self.wg._sync_keys_from_conf() + peer_keys = self.wg.generate_peer_keys('testpeer') + cfg = self.wg.get_peer_config('testpeer', '10.0.0.2', peer_keys['private_key']) + self.assertIn(new_pub, cfg) + + def test_sync_is_noop_when_conf_missing(self): + """_sync_keys_from_conf() must not raise when wg0.conf doesn't exist.""" + # Don't create the conf file + self.wg._sync_keys_from_conf() # should not raise + + def test_apply_config_calls_sync_before_bootstrap(self): + """apply_config() must call _sync_keys_from_conf() so bootstrap uses the live key.""" + new_priv, new_pub = self._generate_key_pair() + cf = self.wg._config_file() + os.makedirs(os.path.dirname(cf), exist_ok=True) + with open(cf, 'w') as f: + f.write('') # empty conf triggers bootstrap + # Write the "new" key to the API key store as if the container auto-generated it + import base64 as _b64 + from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey + priv_obj = X25519PrivateKey.from_private_bytes(_b64.b64decode(new_priv)) + priv_bytes = priv_obj.private_bytes_raw() + pub_bytes = priv_obj.public_key().public_bytes_raw() + with open(os.path.join(self.wg.keys_dir, 'private.key'), 'wb') as f: + f.write(priv_bytes) + with open(os.path.join(self.wg.keys_dir, 'public.key'), 'wb') as f: + f.write(pub_bytes) + + peers_file = os.path.join(self.wg.data_dir, 'peers.json') + with open(peers_file, 'w') as f: + json.dump([], f) + + with patch.object(self.wg, 'get_external_ip', return_value=None), \ + patch('subprocess.run'): + self.wg.apply_config({'port': 51820}) + + content = open(cf).read() + self.assertIn(new_priv, content) + + +if __name__ == '__main__': + unittest.main() diff --git a/webui/src/pages/Dashboard.jsx b/webui/src/pages/Dashboard.jsx index 77da25d..f626106 100644 --- a/webui/src/pages/Dashboard.jsx +++ b/webui/src/pages/Dashboard.jsx @@ -24,9 +24,9 @@ function Dashboard({ isOnline }) { const { domain = 'cell', cell_name = 'mycell' } = useConfig(); const SERVICES = [ { name: 'Cell Home', url: `http://${cell_name}.${domain}`, desc: 'Main UI — no login needed' }, - { name: 'Calendar', url: `http://calendar.${domain}`, desc: 'Login: your WireGuard username' }, - { name: 'Files', url: `http://files.${domain}`, desc: 'Login: admin / admin123' }, - { name: 'Webmail', url: `http://mail.${domain}`, desc: 'Login: admin@rainloop.net / 12345' }, + { name: 'Calendar', url: `http://calendar.${domain}`, desc: 'Use your configured account credentials' }, + { name: 'Files', url: `http://files.${domain}`, desc: 'Use your configured account credentials' }, + { name: 'Webmail', url: `http://mail.${domain}`, desc: 'Use your configured account credentials' }, ]; const [cellStatus, setCellStatus] = useState(null); const [servicesStatus, setServicesStatus] = useState(null); diff --git a/webui/src/pages/PeerDashboard.jsx b/webui/src/pages/PeerDashboard.jsx index 87b0ee3..9a89bae 100644 --- a/webui/src/pages/PeerDashboard.jsx +++ b/webui/src/pages/PeerDashboard.jsx @@ -1,6 +1,6 @@ import React, { useState, useEffect } from 'react'; import { Link } from 'react-router-dom'; -import { Wifi, ArrowDown, ArrowUp, Clock } from 'lucide-react'; +import { Wifi, ArrowDown, ArrowUp, Clock, Calendar, FolderOpen, Mail, Globe } from 'lucide-react'; import { peerAPI } from '../services/api'; function formatBytes(bytes) { @@ -114,6 +114,38 @@ export default function PeerDashboard() {

Quick Access

+ {peer.service_urls && Object.keys(peer.service_urls).length > 0 ? ( +
+ {peer.service_urls.calendar && ( + + + Calendar + + )} + {peer.service_urls.files && ( + + + Files + + )} + {peer.service_urls.mail && ( + + + Mail + + )} + {peer.service_urls.webdav && ( + + + WebDAV + + )} +
+ ) : null} { if (!window.confirm(`Remove peer "${peerName}"?`)) return; try { - await Promise.all([peerRegistryAPI.removePeer(peerName), wireguardAPI.removePeer({ name: peerName })]); + await peerRegistryAPI.removePeer(peerName); fetchPeers(); showToast(`Peer "${peerName}" removed.`); } catch { showToast('Failed to remove peer', 'error'); } @@ -299,7 +296,11 @@ PersistentKeepalive = ${peer.persistent_keepalive || 25}`; const handleConfigDownloaded = async (peerName) => { try { - await fetch(`/api/peers/${peerName}/clear-reinstall`, { method: 'POST', credentials: 'include' }); + await fetch(`/api/peers/${peerName}/clear-reinstall`, { + method: 'POST', + credentials: 'include', + headers: { 'X-CSRF-Token': getCsrfToken() || '' }, + }); setPeers(ps => ps.map(p => p.name === peerName ? { ...p, config_needs_reinstall: false } : p)); } catch {} }; diff --git a/webui/src/pages/WireGuard.jsx b/webui/src/pages/WireGuard.jsx index 8fddc19..f0ed397 100644 --- a/webui/src/pages/WireGuard.jsx +++ b/webui/src/pages/WireGuard.jsx @@ -29,11 +29,11 @@ function WireGuard() { setIsRefreshingIp(true); try { // Refresh IP first (fast) - const ipResp = await fetch('/api/wireguard/refresh-ip', { method: 'POST', credentials: 'include' }); + const ipResp = await fetch('/api/wireguard/refresh-ip', { credentials: 'include' }); const ipData = await ipResp.json(); setServerConfig(prev => ({ ...prev, ...ipData, port_open: 'checking' })); // Then check port (slow — external call) - const portResp = await fetch('/api/wireguard/check-port', { method: 'POST', credentials: 'include' }); + const portResp = await fetch('/api/wireguard/check-port', { credentials: 'include' }); const portData = await portResp.json(); setServerConfig(prev => ({ ...prev, port_open: portData.port_open })); } catch (e) { @@ -56,7 +56,7 @@ function WireGuard() { if (serverConfigResponse) { setServerConfig({ ...serverConfigResponse, port_open: 'checking' }); // Check port asynchronously so page loads fast - fetch('/api/wireguard/check-port', { method: 'POST', credentials: 'include' }) + fetch('/api/wireguard/check-port', { credentials: 'include' }) .then(r => r.json()) .then(d => setServerConfig(prev => ({ ...prev, port_open: d.port_open ?? false }))) .catch(() => setServerConfig(prev => ({ ...prev, port_open: false }))); @@ -66,26 +66,29 @@ function WireGuard() { const peersData = peersResponse.data || []; const wireguardPeers = wireguardResponse.data || []; - // Create a map of WireGuard peers by name for quick lookup + // Create a map of WireGuard peers by public_key for quick lookup const wireguardMap = {}; wireguardPeers.forEach(peer => { - wireguardMap[peer.name] = peer; + if (peer.public_key) wireguardMap[peer.public_key] = peer; }); - + // Merge the data - const mergedPeers = peersData.map(peer => ({ - ...peer, - ...wireguardMap[peer.peer || peer.name], - name: peer.peer || peer.name, - status: 'Online', // For now, assume all peers are online - type: 'WireGuard', - // Preserve important fields that might be overwritten - private_key: peer.private_key, - server_public_key: peer.server_public_key, - server_endpoint: peer.server_endpoint, - allowed_ips: peer.allowed_ips || wireguardMap[peer.peer || peer.name]?.AllowedIPs || '0.0.0.0/0', - persistent_keepalive: peer.persistent_keepalive || wireguardMap[peer.peer || peer.name]?.PersistentKeepalive || 25 - })); + const mergedPeers = peersData.map(peer => { + const wgEntry = wireguardMap[peer.public_key] || {}; + return { + ...peer, + ...wgEntry, + // Registry fields always win over wg0.conf fields for name/keys/endpoint + name: peer.peer || peer.name, + type: 'WireGuard', + private_key: peer.private_key, + server_public_key: peer.server_public_key, + server_endpoint: peer.server_endpoint, + public_key: peer.public_key, + allowed_ips: peer.allowed_ips || wgEntry.allowed_ips || '0.0.0.0/0', + persistent_keepalive: peer.persistent_keepalive || wgEntry.persistent_keepalive || 25, + }; + }); // Load all peer statuses in one call (keyed by public_key) let liveStatuses = {}; diff --git a/webui/src/services/api.js b/webui/src/services/api.js index e1fa5bd..66f6dc9 100644 --- a/webui/src/services/api.js +++ b/webui/src/services/api.js @@ -1,5 +1,20 @@ import axios from 'axios'; +// Module-level CSRF token — populated after login or token refresh +let _csrfToken = null; + +/** + * Update the module-level CSRF token. + * Call this after a successful login with the token returned in the response body. + */ +export function setCsrfToken(token) { + _csrfToken = token; +} + +export function getCsrfToken() { + return _csrfToken; +} + // Create axios instance with base configuration const api = axios.create({ baseURL: import.meta.env.VITE_API_URL || '', @@ -10,10 +25,16 @@ const api = axios.create({ }, }); -// Request interceptor for logging +// Request interceptor — logging + CSRF header injection api.interceptors.request.use( (config) => { console.log(`API Request: ${config.method?.toUpperCase()} ${config.url}`); + // Attach CSRF token for all state-changing methods + const method = (config.method || 'get').toLowerCase(); + if (['post', 'put', 'delete', 'patch'].includes(method) && _csrfToken) { + config.headers = config.headers || {}; + config.headers['X-CSRF-Token'] = _csrfToken; + } return config; }, (error) => { @@ -22,13 +43,36 @@ api.interceptors.request.use( } ); -// Response interceptor for error handling +// Response interceptor — error handling + CSRF token refresh on 403 api.interceptors.response.use( (response) => { return response; }, - (error) => { + async (error) => { console.error('API Response Error:', error.response?.data || error.message); + + // Handle CSRF token expiry: refresh the token and retry the original request once + if ( + error.response?.status === 403 && + error.response?.data?.error === 'CSRF token missing or invalid' && + !error.config._csrfRetry + ) { + try { + const refreshResp = await api.get('/api/auth/csrf-token'); + const newToken = refreshResp.data?.csrf_token; + if (newToken) { + setCsrfToken(newToken); + // Retry the original request with the new token + const retryConfig = { ...error.config, _csrfRetry: true }; + retryConfig.headers = retryConfig.headers || {}; + retryConfig.headers['X-CSRF-Token'] = newToken; + return api(retryConfig); + } + } catch (refreshErr) { + console.error('CSRF token refresh failed:', refreshErr); + } + } + if ( error.response?.status === 401 && !error.config.url.includes('/auth/login') && @@ -107,12 +151,19 @@ export const peerRegistryAPI = { // Auth API export const authAPI = { - login: (username, password) => api.post('/api/auth/login', { username, password }), + login: async (username, password) => { + const response = await api.post('/api/auth/login', { username, password }); + if (response.data?.csrf_token) { + setCsrfToken(response.data.csrf_token); + } + return response; + }, logout: () => api.post('/api/auth/logout'), me: () => api.get('/api/auth/me'), changePassword: (old_password, new_password) => api.post('/api/auth/change-password', { old_password, new_password }), adminResetPassword: (username, new_password) => api.post('/api/auth/admin/reset-password', { username, new_password }), listUsers: () => api.get('/api/auth/users'), + getCsrfToken: () => api.get('/api/auth/csrf-token'), }; // Peer-facing dashboard API