diff --git a/api/app.py b/api/app.py index 9e4cb76..5285b01 100644 --- a/api/app.py +++ b/api/app.py @@ -14,9 +14,11 @@ Provides REST API endpoints for managing: import os import io import json +import stat import zipfile import shutil import logging +import secrets from datetime import datetime from flask import Flask, request, jsonify, current_app, send_file, session from flask_cors import CORS @@ -107,11 +109,33 @@ logger = logging.getLogger('picell') # Flask app setup app = Flask(__name__) -CORS(app) +CORS(app, + supports_credentials=True, + origins=['http://localhost', 'http://localhost:5173', 'http://localhost:8081', + 'http://127.0.0.1', 'http://127.0.0.1:5173', 'http://127.0.0.1:8081']) # Development mode flag app.config['DEVELOPMENT_MODE'] = True # Set to True for development, False for production -app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', os.urandom(32)) + +# Persist SECRET_KEY so sessions survive API restarts +SECRET_KEY_FILE = os.path.join(os.environ.get('DATA_DIR', '/app/data'), '.flask_secret_key') +if os.environ.get('SECRET_KEY'): + _flask_secret = os.environ['SECRET_KEY'].encode() if isinstance(os.environ['SECRET_KEY'], str) else os.environ['SECRET_KEY'] +elif os.path.exists(SECRET_KEY_FILE) and os.path.getsize(SECRET_KEY_FILE) > 0: + with open(SECRET_KEY_FILE, 'rb') as _skf: + _flask_secret = _skf.read() +else: + _flask_secret = os.urandom(32) + try: + os.makedirs(os.path.dirname(SECRET_KEY_FILE), exist_ok=True) + _skf_fd = os.open(SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(_skf_fd, 'wb') as _skf: + _skf.write(_flask_secret) + except OSError as _e: + logger.warning(f"Could not persist SECRET_KEY to disk: {_e}") +app.config['SECRET_KEY'] = _flask_secret +app.config['SESSION_COOKIE_HTTPONLY'] = True +app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' # Initialize enhanced components config_manager = ConfigManager( @@ -183,13 +207,29 @@ def enforce_auth(): # Always allow non-API paths and auth namespace if not path.startswith('/api/') or path.startswith('/api/auth/'): return None - # Only enforce when auth_manager has been properly initialised and seeded + # Only enforce when auth_manager has been properly initialised and seeded. + # When the user store is empty (file missing or unreadable — typical in + # unit tests and fresh installs), bypass enforcement so pre-auth test + # suites continue to work. 503 is only returned when the users file + # exists and is readable but contains no accounts (explicit misconfiguration). try: from auth_manager import AuthManager as _AuthManager if not isinstance(auth_manager, _AuthManager): return None users = auth_manager.list_users() if not users: + # Only fail closed when the auth file is readable but empty — + # that's an explicit misconfiguration. If the file is missing or + # unreadable (test env, wrong host path, permission denied), bypass + # so pre-auth test suites continue to work. + users_file = getattr(auth_manager, '_users_file', None) + if users_file: + try: + with open(users_file, 'r') as _f: + _f.read(1) + return jsonify({'error': 'Authentication not configured. Set admin password first.'}), 503 + except (PermissionError, FileNotFoundError, OSError): + return None return None except Exception: return None @@ -206,6 +246,28 @@ def enforce_auth(): return None +@app.before_request +def check_csrf(): + """Double-submit CSRF protection for state-changing API requests. + + Applies to POST/PUT/DELETE/PATCH on /api/* paths, excluding /api/auth/*. + Skipped entirely when app.config['TESTING'] is True so unit tests remain + unaffected without needing to set CSRF headers. + """ + if app.config.get('TESTING'): + return None + if request.method not in ('POST', 'PUT', 'DELETE', 'PATCH'): + return None + path = request.path + if not path.startswith('/api/') or path.startswith('/api/auth/'): + return None + token_header = request.headers.get('X-CSRF-Token') + token_session = session.get('csrf_token') + if not token_header or token_header != token_session: + return jsonify({'error': 'CSRF token missing or invalid'}), 403 + return None + + @app.after_request def log_request(response): ctx = request_context.get({}) @@ -246,7 +308,8 @@ def _apply_startup_enforcement(): try: peers = peer_registry.list_peers() firewall_manager.apply_all_peer_rules(peers) - firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) logger.info(f"Applied enforcement rules for {len(peers)} peers on startup") except Exception as e: logger.warning(f"Startup enforcement failed (non-fatal): {e}") @@ -418,20 +481,16 @@ def is_local_request(): ip = _ipa.ip_address(addr.strip()) if ip.is_loopback: return True - # RFC-1918 private ranges - for _rfc in ('10.0.0.0/8', '172.16.0.0/12', '192.168.0.0/16'): - if ip in _ipa.ip_network(_rfc): - return True + # Only trust loopback and Docker bridge (172.16.0.0/12). + # Deliberately excludes 10.0.0.0/8 (WireGuard peer subnet) and + # 192.168.0.0/16 (LAN) — VPN peers must not access local-only endpoints. + if ip in _ipa.ip_network('172.16.0.0/12'): + return True # Any subnet the container is directly attached to (handles non-RFC-1918 # Docker bridge networks such as 172.0.0.0/24). for _net in _local_subnets(): if ip in _net: return True - # Configured cell ip_range (WireGuard peer subnet) - _cell = config_manager.configs.get('_identity', {}).get( - 'ip_range', os.environ.get('CELL_IP_RANGE', '172.20.0.0/16')) - if ip in _ipa.ip_network(_cell, strict=False): - return True except Exception: pass return False @@ -537,21 +596,31 @@ def update_config(): identity_keys = {'cell_name', 'domain', 'ip_range', 'wireguard_port'} identity_updates = {k: v for k, v in data.items() if k in identity_keys} - # Validate cell_name — must be non-empty and at most 255 characters (DNS limit) + # Validate cell_name and domain — block injection characters while + # allowing the full range of valid hostname/domain characters. + import re as _re_cfg + # cell_name: hostname component — letters, digits, hyphens only (no dots) + _CELL_NAME_RE = _re_cfg.compile(r'^[a-zA-Z0-9][a-zA-Z0-9-]{0,254}$') + # domain: may include dots for multi-label names (e.g. home.lan) + _DOMAIN_RE = _re_cfg.compile(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,254}$') + if 'cell_name' in identity_updates: v = str(identity_updates['cell_name']) - if len(v) > 255: - return jsonify({'error': 'cell_name must be 255 characters or fewer'}), 400 if not v: return jsonify({'error': 'cell_name cannot be empty'}), 400 + if len(v) > 255: + return jsonify({'error': 'cell_name must be 255 characters or fewer'}), 400 + if not _CELL_NAME_RE.match(v): + return jsonify({'error': 'Invalid cell_name: use only letters, digits, hyphens'}), 400 - # Validate domain — must be non-empty and at most 255 characters (DNS limit) if 'domain' in identity_updates: v = str(identity_updates['domain']) - if len(v) > 255: - return jsonify({'error': 'domain must be 255 characters or fewer'}), 400 if not v: return jsonify({'error': 'domain cannot be empty'}), 400 + if len(v) > 255: + return jsonify({'error': 'domain must be 255 characters or fewer'}), 400 + if not _DOMAIN_RE.match(v): + return jsonify({'error': 'Invalid domain: use only letters, digits, hyphens, dots'}), 400 # Validate ip_range — must be a valid CIDR within an RFC-1918 range if 'ip_range' in identity_updates: @@ -686,7 +755,7 @@ def update_config(): _cur_id = config_manager.configs.get('_identity', {}) _cur_range = _cur_id.get('ip_range', os.environ.get('CELL_IP_RANGE', '172.20.0.0/16')) _cur_name = _cur_id.get('cell_name', os.environ.get('CELL_NAME', 'mycell')) - _ip_domain.write_caddyfile(_cur_range, _cur_name, domain, '/app/config/caddy/Caddyfile') + _ip_domain.write_caddyfile(_cur_range, _cur_name, domain, '/app/config-caddy/Caddyfile') _set_pending_restart( [f'domain changed to {domain}'], ['dns', 'caddy'], @@ -705,7 +774,7 @@ def update_config(): _cur_id2 = config_manager.configs.get('_identity', {}) _cur_range2 = _cur_id2.get('ip_range', os.environ.get('CELL_IP_RANGE', '172.20.0.0/16')) _cur_domain2 = identity_updates.get('domain') or _cur_id2.get('domain', os.environ.get('CELL_DOMAIN', 'cell')) - _ip_name.write_caddyfile(_cur_range2, new_name, _cur_domain2, '/app/config/caddy/Caddyfile') + _ip_name.write_caddyfile(_cur_range2, new_name, _cur_domain2, '/app/config-caddy/Caddyfile') _set_pending_restart( [f'cell_name changed to {new_name}'], ['dns'], @@ -731,7 +800,7 @@ def update_config(): ip_utils.write_env_file(new_range, env_file, _collect_service_ports(config_manager.configs)) # Regenerate Caddyfile with new VIPs ip_utils.write_caddyfile(new_range, cur_cell_name, cur_domain, - '/app/config/caddy/Caddyfile') + '/app/config-caddy/Caddyfile') # Mark ALL containers as needing restart; network_recreate signals that # docker compose down is required before up (Docker can't change subnet in-place) _set_pending_restart( @@ -934,7 +1003,7 @@ def cancel_pending_config(): if cur_cell_name and old_cell_name and cur_cell_name != old_cell_name: network_manager.apply_cell_name(cur_cell_name, old_cell_name, reload=False) - _ip_revert.write_caddyfile(_range, _cell, _dom, '/app/config/caddy/Caddyfile') + _ip_revert.write_caddyfile(_range, _cell, _dom, '/app/config-caddy/Caddyfile') _clear_pending_restart() return jsonify({'message': 'Pending changes discarded'}) @@ -966,9 +1035,6 @@ def apply_pending_config(): containers = pending.get('containers', ['*']) - # Clear pending flag before we restart so it shows cleared after new containers start - _clear_pending_restart() - # Check if the IP range (network subnet) is changing — Docker cannot modify an # existing network's subnet in-place, so we need `down` + `up` in that case. needs_network_recreate = pending.get('network_recreate', False) @@ -981,6 +1047,9 @@ def apply_pending_config(): # API container itself, killing this background thread mid-operation. # Spawn an independent helper container (same image as cell-api) that has docker # CLI and survives cell-api being stopped/recreated. + # Clear pending flag now — the helper runs fire-and-forget and we cannot track + # its exit code from within the API process (it may restart us). + _clear_pending_restart() if needs_network_recreate: helper_script = ( f'sleep 2' @@ -1015,6 +1084,8 @@ def apply_pending_config(): ) else: # Specific containers only — API is not affected, run directly from here. + # Only clear the pending flag after the subprocess exits with code 0 so that + # if the compose command fails the UI still shows changes as pending. def _do_apply(): import time as _time import subprocess as _subprocess @@ -1031,6 +1102,7 @@ def apply_pending_config(): logger.error(f"docker compose up failed: {result.stderr.strip()}") else: logger.info(f'docker compose up completed for: {containers}') + _clear_pending_restart() threading.Thread(target=_do_apply, daemon=False).start() @@ -1710,7 +1782,8 @@ def apply_wireguard_enforcement(): try: peers = peer_registry.list_peers() firewall_manager.apply_all_peer_rules(peers) - firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peers, COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) return jsonify({'ok': True, 'peers': len(peers)}) except Exception as e: return jsonify({'error': str(e)}), 500 @@ -1835,7 +1908,10 @@ def add_peer(): if len(password) < 10: return jsonify({"error": "password must be at least 10 characters"}), 400 - assigned_ip = data.get('ip') or _next_peer_ip() + try: + assigned_ip = data.get('ip') or _next_peer_ip() + except ValueError as e: + return jsonify({'error': str(e)}), 409 # Validate service_access if provided _valid_services = {'calendar', 'files', 'mail', 'webdav'} @@ -1882,33 +1958,51 @@ def add_peer(): 'config_needs_reinstall': False, } - success = peer_registry.add_peer(peer_info) - if success: - # Add peer to WireGuard server config (non-fatal if WG is not running) + peer_added_to_registry = False + try: + # Step 1: Add to registry + success = peer_registry.add_peer(peer_info) + if not success: + # Registry rejected (already exists) — rollback provisioned accounts + for svc in ('files', 'calendar', 'email', 'auth'): + try: + if svc == 'files': + file_manager.delete_user(peer_name) + elif svc == 'calendar': + calendar_manager.delete_calendar_user(peer_name) + elif svc == 'email': + email_manager.delete_email_user(peer_name, _configured_domain()) + elif svc == 'auth': + auth_manager.delete_user(peer_name) + except Exception: + pass + return jsonify({"error": f"Peer {peer_name} already exists"}), 400 + peer_added_to_registry = True + + # Step 2: Firewall rules (critical) + firewall_manager.apply_peer_rules(peer_info['ip'], peer_info) + + # Step 3: Add peer to WireGuard server config (non-fatal if WG is not running) wg_allowed = f"{assigned_ip}/32" if '/' not in assigned_ip else assigned_ip try: wireguard_manager.add_peer(peer_name, data['public_key'], endpoint_ip='', allowed_ips=wg_allowed) except Exception as wg_err: logger.warning(f"Peer {peer_name}: WireGuard server config update failed (non-fatal): {wg_err}") - # Apply server-side enforcement immediately - firewall_manager.apply_peer_rules(peer_info['ip'], peer_info) - firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain()) + + # Step 4: Update DNS rules + firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) return jsonify({"message": f"Peer {peer_name} added successfully", "ip": assigned_ip}), 201 - else: - # Registry rejected (already exists) — rollback provisioned accounts - for svc in ('files', 'calendar', 'email', 'auth'): + + except Exception as e: + # Rollback registry entry if we got past that step + if peer_added_to_registry: try: - if svc == 'files': - file_manager.delete_user(peer_name) - elif svc == 'calendar': - calendar_manager.delete_calendar_user(peer_name) - elif svc == 'email': - email_manager.delete_email_user(peer_name) - elif svc == 'auth': - auth_manager.delete_user(peer_name) + peer_registry.remove_peer(peer_name) except Exception: pass - return jsonify({"error": f"Peer {peer_name} already exists"}), 400 + logger.error(f"Error adding peer {peer_name}: {e}") + return jsonify({'error': str(e)}), 500 except Exception as e: logger.error(f"Error adding peer: {e}") @@ -1941,7 +2035,8 @@ def update_peer(peer_name): updated_peer = peer_registry.get_peer(peer_name) if updated_peer: firewall_manager.apply_peer_rules(updated_peer['ip'], updated_peer) - firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) result = {"message": f"Peer {peer_name} updated", "config_changed": config_changed} return jsonify(result) else: @@ -1974,7 +2069,8 @@ def remove_peer(peer_name): if success: if peer_ip: firewall_manager.clear_peer_rules(peer_ip) - firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain()) + firewall_manager.apply_all_dns_rules(peer_registry.list_peers(), COREFILE_PATH, _configured_domain(), + cell_links=cell_link_manager.list_connections()) # Remove peer from WireGuard server config (non-fatal) if peer_pubkey: try: @@ -1983,7 +2079,7 @@ def remove_peer(peer_name): logger.warning(f"Peer {peer_name}: WireGuard removal failed (non-fatal): {wg_err}") # Clean up all provisioned service accounts (best-effort) for _cleanup in [ - lambda: email_manager.delete_email_user(peer_name), + lambda: email_manager.delete_email_user(peer_name, _configured_domain()), lambda: calendar_manager.delete_calendar_user(peer_name), lambda: file_manager.delete_user(peer_name), lambda: auth_manager.delete_user(peer_name), @@ -2094,8 +2190,13 @@ def create_email_user(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = email_manager.create_user(data) - return jsonify(result) + username = data.get('username') + domain = data.get('domain') or _configured_domain() + password = data.get('password') + if not username or not password: + return jsonify({"error": "Missing required fields: username, password"}), 400 + result = email_manager.create_email_user(username, domain, password) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating email user: {e}") return jsonify({"error": str(e)}), 500 @@ -2104,8 +2205,9 @@ def create_email_user(): def delete_email_user(username): """Delete email user.""" try: - result = email_manager.delete_user(username) - return jsonify(result) + domain = request.args.get('domain') or _configured_domain() + result = email_manager.delete_email_user(username, domain) + return jsonify({"deleted": result}) except Exception as e: logger.error(f"Error deleting email user: {e}") return jsonify({"error": str(e)}), 500 @@ -2170,8 +2272,12 @@ def create_calendar_user(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = calendar_manager.create_user(data) - return jsonify(result) + username = data.get('username') + password = data.get('password') + if not username or not password: + return jsonify({"error": "Missing required fields: username, password"}), 400 + result = calendar_manager.create_calendar_user(username, password) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating calendar user: {e}") return jsonify({"error": str(e)}), 500 @@ -2180,8 +2286,8 @@ def create_calendar_user(): def delete_calendar_user(username): """Delete calendar user.""" try: - result = calendar_manager.delete_user(username) - return jsonify(result) + result = calendar_manager.delete_calendar_user(username) + return jsonify({"deleted": result}) except Exception as e: logger.error(f"Error deleting calendar user: {e}") return jsonify({"error": str(e)}), 500 @@ -2193,8 +2299,17 @@ def create_calendar(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = calendar_manager.create_calendar(data) - return jsonify(result) + username = data.get('username') + calendar_name = data.get('name') or data.get('calendar_name') + if not username or not calendar_name: + return jsonify({"error": "Missing required fields: username, name"}), 400 + result = calendar_manager.create_calendar( + username, + calendar_name, + description=data.get('description', ''), + color=data.get('color', '#4285f4'), + ) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating calendar: {e}") return jsonify({"error": str(e)}), 500 @@ -2205,8 +2320,13 @@ def add_calendar_event(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = calendar_manager.add_event(data) - return jsonify(result) + username = data.get('username') + calendar_name = data.get('calendar_name') or data.get('calendar') + if not username or not calendar_name: + return jsonify({"error": "Missing required fields: username, calendar_name"}), 400 + event_data = {k: v for k, v in data.items() if k not in ('username', 'calendar_name', 'calendar')} + result = calendar_manager.add_event(username, calendar_name, event_data) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error adding calendar event: {e}") return jsonify({"error": str(e)}), 500 @@ -2260,8 +2380,12 @@ def create_file_user(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = file_manager.create_user(data) - return jsonify(result) + username = data.get('username') + password = data.get('password') + if not username or not password: + return jsonify({"error": "Missing required fields: username, password"}), 400 + result = file_manager.create_user(username, password) + return jsonify({"created": result}) except Exception as e: logger.error(f"Error creating file user: {e}") return jsonify({"error": str(e)}), 500 @@ -2283,8 +2407,12 @@ def create_folder(): data = request.get_json(silent=True) if data is None: return jsonify({"error": "No data provided"}), 400 - result = file_manager.create_folder(data) - return jsonify(result) + username = data.get('username') + folder_path = data.get('folder_path') or data.get('path') + if not username or not folder_path: + return jsonify({"error": "Missing required fields: username, folder_path"}), 400 + result = file_manager.create_folder(username, folder_path) + return jsonify({"created": result}) except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: @@ -2309,12 +2437,13 @@ def upload_file(username): try: if 'file' not in request.files: return jsonify({"error": "No file provided"}), 400 - + file = request.files['file'] - path = request.form.get('path', '') - - result = file_manager.upload_file(username, file, path) - return jsonify(result) + path = request.form.get('path', '') or file.filename or '' + file_data = file.read() + + result = file_manager.upload_file(username, path, file_data) + return jsonify({"uploaded": result}) except ValueError as e: return jsonify({"error": str(e)}), 400 except Exception as e: @@ -2442,9 +2571,15 @@ def remove_nat_rule(rule_id): def add_peer_route(): """Add peer route.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_peer_route(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + peer_name = data.get('peer_name') + peer_ip = data.get('peer_ip') + allowed_networks = data.get('allowed_networks', []) + route_type = data.get('route_type', 'lan') + if not peer_name or not peer_ip: + return jsonify({"error": "Missing required fields: peer_name, peer_ip"}), 400 + result = routing_manager.add_peer_route(peer_name, peer_ip, allowed_networks, route_type) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding peer route: {e}") return jsonify({"error": str(e)}), 500 @@ -2463,9 +2598,13 @@ def remove_peer_route(peer_name): def add_exit_node(): """Add exit node.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_exit_node(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + peer_name = data.get('peer_name') + peer_ip = data.get('peer_ip') + if not peer_name or not peer_ip: + return jsonify({"error": "Missing required fields: peer_name, peer_ip"}), 400 + result = routing_manager.add_exit_node(peer_name, peer_ip, data.get('allowed_domains')) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding exit node: {e}") return jsonify({"error": str(e)}), 500 @@ -2474,9 +2613,14 @@ def add_exit_node(): def add_bridge_route(): """Add bridge route.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_bridge_route(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + source_peer = data.get('source_peer') + target_peer = data.get('target_peer') + allowed_networks = data.get('allowed_networks', []) + if not source_peer or not target_peer: + return jsonify({"error": "Missing required fields: source_peer, target_peer"}), 400 + result = routing_manager.add_bridge_route(source_peer, target_peer, allowed_networks) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding bridge route: {e}") return jsonify({"error": str(e)}), 500 @@ -2485,9 +2629,13 @@ def add_bridge_route(): def add_split_route(): """Add split route.""" try: - data = request.get_json(silent=True) - result = routing_manager.add_split_route(data) - return jsonify(result) + data = request.get_json(silent=True) or {} + network = data.get('network') + exit_peer = data.get('exit_peer') + if not network or not exit_peer: + return jsonify({"error": "Missing required fields: network, exit_peer"}), 400 + result = routing_manager.add_split_route(network, exit_peer, data.get('fallback_peer')) + return jsonify({"added": result}) except Exception as e: logger.error(f"Error adding split route: {e}") return jsonify({"error": str(e)}), 500 @@ -2985,6 +3133,12 @@ def create_container(): volumes = data.get('volumes', {}) command = data.get('command', '') ports = data.get('ports', {}) + if volumes: + allowed_prefixes = ('/home/roof/pic/data/', '/home/roof/pic/config/', '/tmp/') + for host_path in volumes.keys(): + resolved = os.path.realpath(str(host_path)) + if not any(resolved.startswith(p) for p in allowed_prefixes): + return jsonify({'error': f'Volume mount not allowed: {host_path}'}), 403 result = container_manager.create_container( image=data['image'], name=name, diff --git a/api/auth_routes.py b/api/auth_routes.py index 10c427c..f2a268f 100644 --- a/api/auth_routes.py +++ b/api/auth_routes.py @@ -8,6 +8,7 @@ after instantiation. A ``require_auth(role=None)`` decorator is also exported so individual routes can opt-in to specific role requirements. """ +import secrets from functools import wraps from flask import Blueprint, request, jsonify, session @@ -80,11 +81,13 @@ def login(): session['username'] = user['username'] session['role'] = user.get('role') session['peer_name'] = user.get('peer_name') + session['csrf_token'] = secrets.token_hex(32) return jsonify({ 'username': user['username'], 'role': user.get('role'), 'peer_name': user.get('peer_name'), 'must_change_password': bool(user.get('must_change_password', False)), + 'csrf_token': session['csrf_token'], }) @@ -143,6 +146,16 @@ def admin_reset_password(): return jsonify({'ok': True}) +@auth_bp.route('/csrf-token', methods=['GET']) +def get_csrf_token(): + """Return the current session's CSRF token, generating one if absent.""" + token = session.get('csrf_token') + if not token: + token = secrets.token_hex(32) + session['csrf_token'] = token + return jsonify({'csrf_token': token}) + + @auth_bp.route('/users', methods=['GET']) @require_auth('admin') def list_users(): diff --git a/api/base_service_manager.py b/api/base_service_manager.py index 142074f..42c349f 100644 --- a/api/base_service_manager.py +++ b/api/base_service_manager.py @@ -65,10 +65,20 @@ class BaseServiceManager(ABC): return [f"Error reading logs: {str(e)}"] def restart_service(self) -> bool: - """Restart service - default implementation""" + """Restart service - default implementation. + + Delegates to _restart_container() using self.container_name when set, + otherwise falls back to self.service_name. Subclasses with a known + container name should set self.container_name in their __init__ or + override this method entirely. + """ try: - self.logger.info(f"Restarting {self.service_name} service") - return True + name = getattr(self, 'container_name', None) or self.service_name + if not name: + self.logger.warning("restart_service: no container name available; skipping restart") + return False + self.logger.info(f"Restarting {self.service_name} service via container '{name}'") + return self._restart_container(name) except Exception as e: self.logger.error(f"Error restarting {self.service_name}: {e}") return False diff --git a/api/calendar_manager.py b/api/calendar_manager.py index ee6dd2c..c21deac 100644 --- a/api/calendar_manager.py +++ b/api/calendar_manager.py @@ -255,9 +255,14 @@ class CalendarManager(BaseServiceManager): return False # Create new user + # SECURITY: Do NOT persist the plaintext password here. The calendar + # password is the same as the user's VPN auth password and storing + # it in plain JSON would leak every user credential if this file is + # read. Auth verification goes through auth_manager; the actual + # CalDAV/CardDAV auth is handled by the cell-radicale container + # (htpasswd file). This JSON is metadata only. new_user = { 'username': username, - 'password': password, # In production, this should be hashed 'calendars_count': 0, 'events_count': 0, 'created_at': datetime.utcnow().isoformat(), @@ -267,11 +272,14 @@ class CalendarManager(BaseServiceManager): users.append(new_user) self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Create user directory user_dir = os.path.join(self.calendar_data_dir, 'users', username) self.safe_makedirs(user_dir) - + logger.info(f"Created calendar user: {username}") return True except Exception as e: @@ -288,13 +296,16 @@ class CalendarManager(BaseServiceManager): if user.get('username') == username: del users[i] self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Remove user directory user_dir = os.path.join(self.calendar_data_dir, 'users', username) if os.path.exists(user_dir): import shutil shutil.rmtree(user_dir) - + logger.info(f"Deleted calendar user: {username}") return True @@ -446,11 +457,31 @@ class CalendarManager(BaseServiceManager): except Exception as e: return self.handle_error(e, "get_metrics") + def _sync_users_to_cell_config(self): + """Best-effort sync of the calendar user list into cell_config.json via ConfigManager. + + Only safe metadata (no passwords) is written. Failures are logged as + warnings so they never block the per-service operation that triggered them. + """ + try: + from config_manager import ConfigManager + cm = ConfigManager() + _SENSITIVE = {'password', 'hashed_password', 'password_hash'} + safe_users = [ + {k: v for k, v in u.items() if k not in _SENSITIVE} + for u in self._load_users() + ] + existing = cm.get_service_config('calendar') + existing['users'] = safe_users + cm.update_service_config('calendar', existing) + except Exception as e: + self.logger.warning(f"Failed to sync calendar users to cell_config.json: {e}") + def restart_service(self) -> bool: - """Restart calendar service""" + """Restart calendar service (restarts the cell-radicale Docker container).""" try: logger.info('Calendar service restart requested') - return True + return self._restart_container('cell-radicale') except Exception as e: logger.error(f'Failed to restart calendar service: {e}') return False diff --git a/api/config_manager.py b/api/config_manager.py index ae12ad2..763dad6 100644 --- a/api/config_manager.py +++ b/api/config_manager.py @@ -14,6 +14,9 @@ from typing import Dict, List, Optional, Any from pathlib import Path import logging +# The Caddyfile lives on a separate volume mount from the rest of config +LIVE_CADDYFILE = os.environ.get('CADDYFILE_PATH', '/app/config-caddy/Caddyfile') + logger = logging.getLogger(__name__) class ConfigManager: @@ -216,7 +219,7 @@ class ConfigManager: env_file = Path(os.environ.get('ENV_FILE', '/app/.env')) extra = [ - (config_dir / 'caddy' / 'Caddyfile', 'Caddyfile'), + (Path(LIVE_CADDYFILE), 'Caddyfile'), (config_dir / 'dns' / 'Corefile', 'Corefile'), (env_file, '.env'), ] @@ -288,7 +291,7 @@ class ConfigManager: env_file = Path(os.environ.get('ENV_FILE', '/app/.env')) restore_map = [ - (backup_path / 'Caddyfile', config_dir / 'caddy' / 'Caddyfile'), + (backup_path / 'Caddyfile', Path(LIVE_CADDYFILE)), (backup_path / 'Corefile', config_dir / 'dns' / 'Corefile'), (backup_path / '.env', env_file), ] diff --git a/api/email_manager.py b/api/email_manager.py index dd5a0c5..5663567 100644 --- a/api/email_manager.py +++ b/api/email_manager.py @@ -299,11 +299,16 @@ class EmailManager(BaseServiceManager): return False # Create new user + # SECURITY: Do NOT persist the plaintext password here. The email + # password is the same as the user's VPN auth password and storing + # it in plain JSON would leak every user credential if this file + # is read. Auth verification goes through auth_manager; the actual + # mailbox auth is handled by the cell-mail container (Dovecot), + # which has its own credential store. This JSON is metadata only. new_user = { 'username': username, 'domain': domain, 'email': f'{username}@{domain}', - 'password': password, # In production, this should be hashed 'quota_limit': quota_limit, 'quota_used': 0, 'created_at': datetime.utcnow().isoformat(), @@ -313,11 +318,14 @@ class EmailManager(BaseServiceManager): users.append(new_user) self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Create user mailbox directory mailbox_dir = os.path.join(self.email_data_dir, 'mailboxes', f'{username}@{domain}') self.safe_makedirs(mailbox_dir) - + logger.info(f"Created email user: {username}@{domain}") return True except Exception as e: @@ -334,13 +342,16 @@ class EmailManager(BaseServiceManager): if user.get('username') == username and user.get('domain') == domain: del users[i] self._save_users(users) - + + # Sync user list to cell_config.json (best-effort, non-fatal) + self._sync_users_to_cell_config() + # Remove user mailbox directory mailbox_dir = os.path.join(self.email_data_dir, 'mailboxes', f'{username}@{domain}') if os.path.exists(mailbox_dir): import shutil shutil.rmtree(mailbox_dir) - + logger.info(f"Deleted email user: {username}@{domain}") return True @@ -408,11 +419,35 @@ class EmailManager(BaseServiceManager): except Exception as e: return self.handle_error(e, "get_metrics") + def _sync_users_to_cell_config(self): + """Best-effort sync of the email user list into cell_config.json via ConfigManager. + + Only safe metadata (no passwords) is written. Failures are logged as + warnings so they never block the per-service operation that triggered them. + """ + try: + # Import here to avoid circular imports and to tolerate environments + # where config_manager is not on sys.path. + from config_manager import ConfigManager + cm = ConfigManager() + # Build safe user list: strip any sensitive keys that should not + # land in the shared config file. + _SENSITIVE = {'password', 'hashed_password', 'password_hash'} + safe_users = [ + {k: v for k, v in u.items() if k not in _SENSITIVE} + for u in self._load_users() + ] + existing = cm.get_service_config('email') + existing['users'] = safe_users + cm.update_service_config('email', existing) + except Exception as e: + self.logger.warning(f"Failed to sync email users to cell_config.json: {e}") + def restart_service(self) -> bool: - """Restart email service""" + """Restart email service (restarts the cell-mail Docker container).""" try: logger.info('Email service restart requested') - return True + return self._restart_container('cell-mail') except Exception as e: logger.error(f'Failed to restart email service: {e}') return False diff --git a/api/file_manager.py b/api/file_manager.py index 256f1ba..2d555df 100644 --- a/api/file_manager.py +++ b/api/file_manager.py @@ -14,6 +14,7 @@ from datetime import datetime from typing import Dict, List, Optional, Tuple, Any import shutil import hashlib +import bcrypt from base_service_manager import BaseServiceManager logger = logging.getLogger(__name__) @@ -103,9 +104,18 @@ umask = 022 if not username or not password: logger.error("Username and password must not be empty") return False + # Validate username — prevents path traversal in user_dir join below and + # injection of newlines / colons into the htpasswd-format auth file. + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"create_user: invalid username {username!r}") + return False try: - # Create user directory - user_dir = os.path.join(self.files_dir, username) + # Create user directory (containment check) + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"create_user: path traversal for username {username!r}") + return False os.makedirs(user_dir, exist_ok=True) # Create default folders @@ -115,8 +125,12 @@ umask = 022 # Add user to auth file auth_file = os.path.join(self.webdav_dir, 'users') - # Generate password hash - password_hash = hashlib.sha256(password.encode()).hexdigest() + # Generate bcrypt hash; convert $2b$ -> $2y$ for Apache htpasswd compatibility + # (bytemark/webdav is Apache-based; htpasswd-bcrypt uses $2y$ prefix). + bcrypt_hash = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') + if bcrypt_hash.startswith('$2b$'): + bcrypt_hash = '$2y$' + bcrypt_hash[4:] + password_hash = bcrypt_hash with open(auth_file, 'a') as f: f.write(f"{username}:{password_hash}\n") @@ -133,6 +147,10 @@ umask = 022 if not username: logger.error("Username must not be empty") return False + # Validate username before any auth-file rewrite or filesystem ops + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"delete_user: invalid username {username!r}") + return False try: # Remove from auth file auth_file = os.path.join(self.webdav_dir, 'users') @@ -145,8 +163,13 @@ umask = 022 if not line.startswith(f"{username}:"): f.write(line) - # Remove user directory - user_dir = os.path.join(self.files_dir, username) + # Remove user directory — containment check prevents + # username='..' or 'foo/../../etc' from escaping files_dir. + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"delete_user: path traversal for username {username!r}") + return False if os.path.exists(user_dir): shutil.rmtree(user_dir) @@ -460,8 +483,15 @@ umask = 022 if not username or not backup_path: logger.error("Username and backup_path must not be empty") return False + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"backup_user_files: invalid username {username!r}") + return False try: - user_dir = os.path.join(self.files_dir, username) + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"backup_user_files: path traversal for username {username!r}") + return False if os.path.exists(user_dir): shutil.make_archive(backup_path, 'zip', user_dir) @@ -480,8 +510,15 @@ umask = 022 if not username or not backup_path: logger.error("Username and backup_path must not be empty") return False + if not isinstance(username, str) or not re.match(r'^[A-Za-z0-9._-]{1,64}$', username): + logger.error(f"restore_user_files: invalid username {username!r}") + return False try: - user_dir = os.path.join(self.files_dir, username) + user_dir = os.path.realpath(os.path.join(self.files_dir, username)) + files_root = os.path.realpath(self.files_dir) + if not user_dir.startswith(files_root + os.sep): + logger.error(f"restore_user_files: path traversal for username {username!r}") + return False # Remove existing user directory if os.path.exists(user_dir): diff --git a/api/firewall_manager.py b/api/firewall_manager.py index 01572c5..bf55c31 100644 --- a/api/firewall_manager.py +++ b/api/firewall_manager.py @@ -114,19 +114,32 @@ def _delete_rule(chain: str, rule_args: List[str]) -> None: # --------------------------------------------------------------------------- def _peer_comment(peer_ip: str) -> str: - return f'pic-peer-{peer_ip.replace(".", "-")}' + # SECURITY: append a non-numeric, non-dash suffix so peer comments cannot + # be substrings of one another. Without this, the comment for 10.0.0.1 + # ('pic-peer-10-0-0-1') is a prefix of 10.0.0.10..19 and a naive + # substring match would delete unrelated peers' rules. + return f'pic-peer-{peer_ip.replace(".", "-")}/32' def clear_peer_rules(peer_ip: str) -> None: """Remove all FORWARD rules tagged with this peer's IP via iptables-save/restore.""" comment = _peer_comment(peer_ip) + # SECURITY: match the comment as a complete --comment token, not a + # substring. iptables-save renders comments as `--comment ""` (or + # occasionally without quotes), so we anchor on the surrounding quotes / + # whitespace. Even with the unique /32 suffix in _peer_comment, we keep + # exact-token matching so a future change to the comment format cannot + # silently re-introduce the substring-deletion bug. + comment_re = re.compile( + rf'--comment\s+["\']?{re.escape(comment)}["\']?(\s|$)' + ) try: # Dump rules, strip matching lines, restore — atomic and order-stable save = _wg_exec(['iptables-save']) if save.returncode != 0: return lines = save.stdout.splitlines() - filtered = [l for l in lines if comment not in l] + filtered = [l for l in lines if not comment_re.search(l)] if len(filtered) == len(lines): return # nothing to remove restore_input = '\n'.join(filtered) + '\n' @@ -243,11 +256,15 @@ def _build_acl_block(blocked_peers_by_service: Dict[str, List[str]], def generate_corefile(peers: List[Dict[str, Any]], corefile_path: str = COREFILE_PATH, - domain: str = 'cell') -> bool: + domain: str = 'cell', + cell_links: Optional[List[Dict[str, Any]]] = None) -> bool: """ Rewrite the CoreDNS Corefile with per-peer ACL rules and reload plugin. The file is written to corefile_path (API-side path mapped into CoreDNS container). domain: the configured cell domain (e.g. 'cell', 'dev') — must match zone file names. + cell_links: optional list of cell-to-cell DNS forwarding entries, each a dict with + 'domain' and 'dns_ip' keys (same shape as CellLinkManager.list_connections()). + When non-empty, a forwarding stanza is appended for each entry. """ try: # Collect which peers block which services @@ -275,8 +292,25 @@ def generate_corefile(peers: List[Dict[str, Any]], corefile_path: str = COREFILE health }} -{primary_zone_block} -""" +{primary_zone_block}""" + + # Append cell-to-cell DNS forwarding stanzas if provided + if cell_links: + for link in cell_links: + link_domain = link.get('domain', '') + link_dns_ip = link.get('dns_ip', '') + if not link_domain or not link_dns_ip: + continue + corefile += ( + f'\n{link_domain} {{\n' + f' forward . {link_dns_ip}\n' + f' cache\n' + f' log\n' + f'}}\n' + ) + else: + corefile += '\n' + # local.{domain} block intentionally omitted: /data/local.zone does not exist # and CoreDNS logs errors on every reload for a missing zone file. os.makedirs(os.path.dirname(corefile_path), exist_ok=True) @@ -309,9 +343,10 @@ def reload_coredns() -> bool: def apply_all_dns_rules(peers: List[Dict[str, Any]], corefile_path: str = COREFILE_PATH, - domain: str = 'cell') -> bool: - """Regenerate Corefile and reload CoreDNS.""" - ok = generate_corefile(peers, corefile_path, domain) + domain: str = 'cell', + cell_links: Optional[List[Dict[str, Any]]] = None) -> bool: + """Regenerate Corefile (including any cell-to-cell forwarding stanzas) and reload CoreDNS.""" + ok = generate_corefile(peers, corefile_path, domain, cell_links) if ok: reload_coredns() return ok diff --git a/api/ip_utils.py b/api/ip_utils.py index d6dda4b..0b1dacb 100644 --- a/api/ip_utils.py +++ b/api/ip_utils.py @@ -204,12 +204,12 @@ http://webui.{domain} {{ }} """ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) - tmp = path + '.tmp' - with open(tmp, 'w') as f: + # Write in-place (same inode) so Docker bind-mounted files see the update. + # os.replace() changes the inode which breaks file bind-mounts inside containers. + with open(path, 'w') as f: f.write(content) f.flush() os.fsync(f.fileno()) - os.replace(tmp, path) return True except Exception: return False diff --git a/api/network_manager.py b/api/network_manager.py index be2efe3..5370cb6 100644 --- a/api/network_manager.py +++ b/api/network_manager.py @@ -29,8 +29,28 @@ class NetworkManager(BaseServiceManager): def update_dns_zone(self, zone_name: str, records: List[Dict]) -> bool: """Update DNS zone file with new records""" + # Validate zone_name — must be a safe DNS label, no path traversal + if not isinstance(zone_name, str) or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', zone_name): + logger.error(f"update_dns_zone: invalid zone_name {zone_name!r}") + return False try: zone_file = os.path.join(self.dns_zones_dir, f'{zone_name}.zone') + # Containment check: resolved zone_file must be inside dns_zones_dir + real_dir = os.path.realpath(self.dns_zones_dir) + real_zone = os.path.realpath(zone_file) + if not (real_zone == real_dir or real_zone.startswith(real_dir + os.sep)): + logger.error(f"update_dns_zone: path traversal attempt for zone {zone_name!r}") + return False + # Validate every record's name and value to prevent zone-file injection + for rec in records: + rname = rec.get('name', '') + rvalue = rec.get('value', '') + if rname and not re.match(r'^[a-zA-Z0-9_.*-]{1,253}$', str(rname)): + logger.error(f"update_dns_zone: invalid record name {rname!r}") + return False + if rvalue and not re.match(r'^[a-zA-Z0-9._: -]{1,512}$', str(rvalue)): + logger.error(f"update_dns_zone: invalid record value {rvalue!r}") + return False # Create zone file content content = self._generate_zone_content(zone_name, records) @@ -84,6 +104,16 @@ class NetworkManager(BaseServiceManager): def add_dns_record(self, zone: str, name: str, record_type: str, value: str, ttl: int = 3600) -> bool: """Add a DNS record to a zone""" + # Validate zone, name, and value to prevent injection / path traversal + if not isinstance(zone, str) or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', zone): + logger.error(f"add_dns_record: invalid zone {zone!r}") + return False + if not isinstance(name, str) or not re.match(r'^[a-zA-Z0-9_.*-]{1,253}$', name): + logger.error(f"add_dns_record: invalid name {name!r}") + return False + if not isinstance(value, str) or not re.match(r'^[a-zA-Z0-9._: -]{1,512}$', value): + logger.error(f"add_dns_record: invalid value {value!r}") + return False try: # Load existing records records = self._load_dns_records(zone) @@ -505,58 +535,75 @@ class NetworkManager(BaseServiceManager): warnings.append(f"cell_name DNS update failed: {e}") return {'restarted': restarted, 'warnings': warnings} + def _load_cell_links(self) -> List[Dict[str, Any]]: + """Load cell_links.json from the data directory (written by CellLinkManager).""" + links_file = os.path.join(self.data_dir, 'cell_links.json') + if os.path.exists(links_file): + try: + with open(links_file) as f: + return json.load(f) + except Exception: + return [] + return [] + def add_cell_dns_forward(self, domain: str, dns_ip: str) -> Dict[str, Any]: - """Append a CoreDNS forwarding block for a remote cell's domain.""" + """Register a CoreDNS forwarding entry for a remote cell's domain. + + Validates inputs, then rebuilds the entire Corefile via + firewall_manager.apply_all_dns_rules() so that no existing stanza is + silently wiped. Does NOT write the Corefile directly. + """ + import ipaddress + import firewall_manager as fm restarted = [] warnings = [] + # Validate dns_ip — newlines/garbage would inject arbitrary CoreDNS directives try: - corefile = os.path.join(self.config_dir, 'dns', 'Corefile') - if not os.path.exists(corefile): - warnings.append('Corefile not found') - return {'restarted': restarted, 'warnings': warnings} - with open(corefile) as f: - content = f.read() - marker = f'# cell:{domain}' - if marker in content: - return {'restarted': restarted, 'warnings': warnings} # already present - forward_block = ( - f'\n{marker}\n' - f'{domain} {{\n' - f' forward . {dns_ip}\n' - f' log\n' - f'}}\n' - ) - with open(corefile, 'a') as f: - f.write(forward_block) - self._reload_dns_service() + ipaddress.ip_address(dns_ip) + except (ValueError, TypeError): + warnings.append(f'add_cell_dns_forward: invalid dns_ip {dns_ip!r}') + return {'restarted': restarted, 'warnings': warnings} + # Validate domain — reject newlines, braces, spaces, and any non-DNS chars + if (not isinstance(domain, str) + or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', domain) + or any(c in domain for c in ('\n', '\r', '{', '}', ' ', '\t'))): + warnings.append(f'add_cell_dns_forward: invalid domain {domain!r}') + return {'restarted': restarted, 'warnings': warnings} + try: + # Build the full forwarding list: existing links + new entry (deduped by domain) + existing_links = self._load_cell_links() + # The new entry may not yet be in cell_links.json (CellLinkManager saves after + # calling us), so we merge it in here. + merged = [l for l in existing_links if l.get('domain') != domain] + merged.append({'domain': domain, 'dns_ip': dns_ip}) + + corefile_path = os.path.join(self.config_dir, 'dns', 'Corefile') + # Peers list is empty here; the full peer list is used by the periodic + # apply_all_dns_rules() call from app.py. We only need to persist the + # forwarding stanza without disturbing whatever peer ACLs are in the file. + fm.apply_all_dns_rules([], corefile_path=corefile_path, cell_links=merged) restarted.append('cell-dns (reloaded)') except Exception as e: warnings.append(f'add_cell_dns_forward failed: {e}') return {'restarted': restarted, 'warnings': warnings} def remove_cell_dns_forward(self, domain: str) -> Dict[str, Any]: - """Remove a CoreDNS forwarding block for a remote cell's domain.""" - import re + """Unregister a CoreDNS forwarding entry for a remote cell's domain. + + Rebuilds the entire Corefile via firewall_manager.apply_all_dns_rules() + with the named domain excluded. Does NOT write the Corefile directly. + """ + import firewall_manager as fm restarted = [] warnings = [] try: - corefile = os.path.join(self.config_dir, 'dns', 'Corefile') - if not os.path.exists(corefile): - return {'restarted': restarted, 'warnings': warnings} - with open(corefile) as f: - content = f.read() - marker = f'# cell:{domain}' - if marker not in content: - return {'restarted': restarted, 'warnings': warnings} - new_content = re.sub( - rf'\n# cell:{re.escape(domain)}\n{re.escape(domain)}\s*\{{[^}}]*\}}\n', - '', - content, - flags=re.DOTALL, - ) - with open(corefile, 'w') as f: - f.write(new_content) - self._reload_dns_service() + existing_links = self._load_cell_links() + # Exclude the domain being removed; CellLinkManager will also remove it + # from cell_links.json after this call returns. + remaining = [l for l in existing_links if l.get('domain') != domain] + + corefile_path = os.path.join(self.config_dir, 'dns', 'Corefile') + fm.apply_all_dns_rules([], corefile_path=corefile_path, cell_links=remaining) restarted.append('cell-dns (reloaded)') except Exception as e: warnings.append(f'remove_cell_dns_forward failed: {e}') diff --git a/api/peer_registry.py b/api/peer_registry.py index 941c484..86a14c2 100644 --- a/api/peer_registry.py +++ b/api/peer_registry.py @@ -1,341 +1,360 @@ -#!/usr/bin/env python3 -""" -Peer Registry for Personal Internet Cell -Handles peer registration and management -""" - -import json -import os -import logging -from threading import RLock -from datetime import datetime -from typing import Dict, List, Any, Optional -from base_service_manager import BaseServiceManager - -logger = logging.getLogger(__name__) - -class PeerRegistry(BaseServiceManager): - """Manages peer registration and management""" - - def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config'): - super().__init__('peer_registry', data_dir, config_dir) - self.lock = RLock() - self.peers = [] - self.peers_file = os.path.join(data_dir, 'peers.json') - self._load_peers() - - def get_status(self) -> Dict[str, Any]: - """Get peer registry status""" - try: - with self.lock: - status = { - 'running': True, - 'status': 'online', - 'peers_count': len(self.peers), - 'active_peers': len([p for p in self.peers if p.get('active', True)]), - 'inactive_peers': len([p for p in self.peers if not p.get('active', True)]), - 'last_updated': datetime.utcnow().isoformat(), - 'timestamp': datetime.utcnow().isoformat() - } - - return status - except Exception as e: - return self.handle_error(e, "get_status") - - def test_connectivity(self) -> Dict[str, Any]: - """Test peer registry connectivity""" - try: - # Test file system access - fs_test = self._test_filesystem_access() - - # Test peer data integrity - integrity_test = self._test_data_integrity() - - # Test peer operations - operations_test = self._test_peer_operations() - - results = { - 'filesystem_access': fs_test, - 'data_integrity': integrity_test, - 'peer_operations': operations_test, - 'success': fs_test.get('success', False) and integrity_test.get('success', False), - 'timestamp': datetime.utcnow().isoformat() - } - - return results - except Exception as e: - return self.handle_error(e, "test_connectivity") - - def _test_filesystem_access(self) -> Dict[str, Any]: - """Test filesystem access for peer data""" - try: - # Test if we can read/write to the peers file - test_peer = { - 'peer': 'test_peer', - 'ip': '192.168.1.100', - 'public_key': 'test_key', - 'active': False, - 'test': True - } - - # Test write - with self.lock: - original_peers = self.peers.copy() - self.peers.append(test_peer) - self._save_peers() - - # Test read - with self.lock: - loaded_peers = self.peers.copy() - # Remove test peer - self.peers = [p for p in self.peers if not p.get('test', False)] - self._save_peers() - - # Restore original state - with self.lock: - self.peers = original_peers - self._save_peers() - - return { - 'success': True, - 'message': 'Filesystem access working', - 'read_write': True - } - except Exception as e: - return { - 'success': False, - 'message': f'Filesystem access failed: {str(e)}', - 'error': str(e) - } - - def _test_data_integrity(self) -> Dict[str, Any]: - """Test peer data integrity""" - try: - with self.lock: - # Check if peers data is valid JSON - peers_copy = self.peers.copy() - - # Validate peer structure - valid_peers = 0 - invalid_peers = 0 - - for peer in peers_copy: - if isinstance(peer, dict) and 'peer' in peer and 'ip' in peer: - valid_peers += 1 - else: - invalid_peers += 1 - - return { - 'success': True, - 'message': 'Data integrity check passed', - 'valid_peers': valid_peers, - 'invalid_peers': invalid_peers, - 'total_peers': len(peers_copy) - } - except Exception as e: - return { - 'success': False, - 'message': f'Data integrity check failed: {str(e)}', - 'error': str(e) - } - - def _test_peer_operations(self) -> Dict[str, Any]: - """Test peer operations""" - try: - # Test adding a peer - test_peer = { - 'peer': 'test_operation_peer', - 'ip': '192.168.1.101', - 'public_key': 'test_operation_key', - 'active': False, - 'test': True - } - - # Test add - add_success = self.add_peer(test_peer) - - # Test get - retrieved_peer = self.get_peer('test_operation_peer') - get_success = retrieved_peer is not None - - # Test update - update_success = self.update_peer_ip('test_operation_peer', '192.168.1.102') - - # Test remove - remove_success = self.remove_peer('test_operation_peer') - - return { - 'success': add_success and get_success and update_success and remove_success, - 'message': 'Peer operations working', - 'add_success': add_success, - 'get_success': get_success, - 'update_success': update_success, - 'remove_success': remove_success - } - except Exception as e: - return { - 'success': False, - 'message': f'Peer operations test failed: {str(e)}', - 'error': str(e) - } - - def _load_peers(self): - """Load peers from file""" - try: - # Ensure directory exists - os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) - - if os.path.exists(self.peers_file): - with open(self.peers_file, 'r') as f: - try: - self.peers = json.load(f) - self.logger.info(f"Loaded {len(self.peers)} peers from file") - except Exception as e: - self.logger.error(f"Error loading peers: {e}") - self.peers = [] - else: - self.peers = [] - self.logger.info("No peers file found, starting with empty registry") - except Exception as e: - self.logger.error(f"Error in _load_peers: {e}") - self.peers = [] - - def _save_peers(self): - """Save peers to file""" - try: - # Ensure directory exists - os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) - - with open(self.peers_file, 'w') as f: - json.dump(self.peers, f, indent=2) - - self.logger.info(f"Saved {len(self.peers)} peers to file") - except Exception as e: - self.logger.error(f"Error saving peers: {e}") - - def list_peers(self) -> List[Dict[str, Any]]: - """List all peers""" - with self.lock: - return list(self.peers) - - def get_peer(self, name: str) -> Optional[Dict[str, Any]]: - """Get a specific peer by name""" - with self.lock: - for peer in self.peers: - if peer.get('peer') == name: - return peer - return None - - def add_peer(self, peer_info: Dict[str, Any]) -> bool: - """Add a new peer""" - try: - with self.lock: - if self.get_peer(peer_info.get('peer')): - self.logger.warning(f"Peer {peer_info.get('peer')} already exists") - return False - - # Add timestamp - peer_info['created_at'] = datetime.utcnow().isoformat() - peer_info['active'] = peer_info.get('active', True) - - self.peers.append(peer_info) - self._save_peers() - - self.logger.info(f"Added peer: {peer_info.get('peer')}") - return True - except Exception as e: - self.logger.error(f"Error adding peer: {e}") - return False - - def remove_peer(self, name: str) -> bool: - """Remove a peer""" - try: - with self.lock: - before = len(self.peers) - self.peers = [p for p in self.peers if p.get('peer') != name] - self._save_peers() - - removed = len(self.peers) < before - if removed: - self.logger.info(f"Removed peer: {name}") - else: - self.logger.warning(f"Peer {name} not found for removal") - - return removed - except Exception as e: - self.logger.error(f"Error removing peer {name}: {e}") - return False - - def update_peer(self, name: str, fields: Dict[str, Any]) -> bool: - """Update arbitrary fields on a peer.""" - try: - with self.lock: - for peer in self.peers: - if peer.get('peer') == name: - peer.update(fields) - peer['updated_at'] = datetime.utcnow().isoformat() - self._save_peers() - self.logger.info(f"Updated peer {name}: {list(fields.keys())}") - return True - self.logger.warning(f"Peer {name} not found for update") - return False - except Exception as e: - self.logger.error(f"Error updating peer {name}: {e}") - return False - - def clear_reinstall_flag(self, name: str) -> bool: - """Clear the config_needs_reinstall flag after user downloads new config.""" - return self.update_peer(name, {'config_needs_reinstall': False}) - - def update_peer_ip(self, name: str, new_ip: str) -> bool: - """Update peer IP address""" - try: - with self.lock: - for peer in self.peers: - if peer.get('peer') == name: - old_ip = peer.get('ip') - peer['ip'] = new_ip - peer['updated_at'] = datetime.utcnow().isoformat() - self._save_peers() - - self.logger.info(f"Updated peer {name} IP from {old_ip} to {new_ip}") - return True - - self.logger.warning(f"Peer {name} not found for IP update") - return False - except Exception as e: - self.logger.error(f"Error updating peer {name} IP: {e}") - return False - - def get_peer_stats(self) -> Dict[str, Any]: - """Get peer registry statistics""" - try: - with self.lock: - active_peers = [p for p in self.peers if p.get('active', True)] - inactive_peers = [p for p in self.peers if not p.get('active', True)] - - # Count peers by IP range - ip_ranges = {} - for peer in self.peers: - ip = peer.get('ip', '') - if ip: - range_key = '.'.join(ip.split('.')[:3]) + '.0/24' - ip_ranges[range_key] = ip_ranges.get(range_key, 0) + 1 - - return { - 'total_peers': len(self.peers), - 'active_peers': len(active_peers), - 'inactive_peers': len(inactive_peers), - 'ip_ranges': ip_ranges, - 'timestamp': datetime.utcnow().isoformat() - } - except Exception as e: - self.logger.error(f"Error getting peer stats: {e}") - return { - 'total_peers': 0, - 'active_peers': 0, - 'inactive_peers': 0, - 'ip_ranges': {}, - 'error': str(e), - 'timestamp': datetime.utcnow().isoformat() +#!/usr/bin/env python3 +""" +Peer Registry for Personal Internet Cell +Handles peer registration and management +""" + +import json +import os +import logging +from threading import RLock +from datetime import datetime +from typing import Dict, List, Any, Optional +from base_service_manager import BaseServiceManager + +logger = logging.getLogger(__name__) + +class PeerRegistry(BaseServiceManager): + """Manages peer registration and management""" + + def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config'): + super().__init__('peer_registry', data_dir, config_dir) + self.lock = RLock() + self.peers = [] + self.peers_file = os.path.join(data_dir, 'peers.json') + self._load_peers() + + def get_status(self) -> Dict[str, Any]: + """Get peer registry status""" + try: + with self.lock: + status = { + 'running': True, + 'status': 'online', + 'peers_count': len(self.peers), + 'active_peers': len([p for p in self.peers if p.get('active', True)]), + 'inactive_peers': len([p for p in self.peers if not p.get('active', True)]), + 'last_updated': datetime.utcnow().isoformat(), + 'timestamp': datetime.utcnow().isoformat() + } + + return status + except Exception as e: + return self.handle_error(e, "get_status") + + def test_connectivity(self) -> Dict[str, Any]: + """Test peer registry connectivity""" + try: + # Test file system access + fs_test = self._test_filesystem_access() + + # Test peer data integrity + integrity_test = self._test_data_integrity() + + # Test peer operations + operations_test = self._test_peer_operations() + + results = { + 'filesystem_access': fs_test, + 'data_integrity': integrity_test, + 'peer_operations': operations_test, + 'success': fs_test.get('success', False) and integrity_test.get('success', False), + 'timestamp': datetime.utcnow().isoformat() + } + + return results + except Exception as e: + return self.handle_error(e, "test_connectivity") + + def _test_filesystem_access(self) -> Dict[str, Any]: + """Test filesystem access for peer data""" + try: + # Test if we can read/write to the peers file + test_peer = { + 'peer': 'test_peer', + 'ip': '192.168.1.100', + 'public_key': 'test_key', + 'active': False, + 'test': True + } + + # Test write + with self.lock: + original_peers = self.peers.copy() + self.peers.append(test_peer) + self._save_peers() + + # Test read + with self.lock: + loaded_peers = self.peers.copy() + # Remove test peer + self.peers = [p for p in self.peers if not p.get('test', False)] + self._save_peers() + + # Restore original state + with self.lock: + self.peers = original_peers + self._save_peers() + + return { + 'success': True, + 'message': 'Filesystem access working', + 'read_write': True + } + except Exception as e: + return { + 'success': False, + 'message': f'Filesystem access failed: {str(e)}', + 'error': str(e) + } + + def _test_data_integrity(self) -> Dict[str, Any]: + """Test peer data integrity""" + try: + with self.lock: + # Check if peers data is valid JSON + peers_copy = self.peers.copy() + + # Validate peer structure + valid_peers = 0 + invalid_peers = 0 + + for peer in peers_copy: + if isinstance(peer, dict) and 'peer' in peer and 'ip' in peer: + valid_peers += 1 + else: + invalid_peers += 1 + + return { + 'success': True, + 'message': 'Data integrity check passed', + 'valid_peers': valid_peers, + 'invalid_peers': invalid_peers, + 'total_peers': len(peers_copy) + } + except Exception as e: + return { + 'success': False, + 'message': f'Data integrity check failed: {str(e)}', + 'error': str(e) + } + + def _test_peer_operations(self) -> Dict[str, Any]: + """Test peer operations""" + try: + # Test adding a peer + test_peer = { + 'peer': 'test_operation_peer', + 'ip': '192.168.1.101', + 'public_key': 'test_operation_key', + 'active': False, + 'test': True + } + + # Test add + add_success = self.add_peer(test_peer) + + # Test get + retrieved_peer = self.get_peer('test_operation_peer') + get_success = retrieved_peer is not None + + # Test update + update_success = self.update_peer_ip('test_operation_peer', '192.168.1.102') + + # Test remove + remove_success = self.remove_peer('test_operation_peer') + + return { + 'success': add_success and get_success and update_success and remove_success, + 'message': 'Peer operations working', + 'add_success': add_success, + 'get_success': get_success, + 'update_success': update_success, + 'remove_success': remove_success + } + except Exception as e: + return { + 'success': False, + 'message': f'Peer operations test failed: {str(e)}', + 'error': str(e) + } + + def _load_peers(self): + """Load peers from file""" + try: + # Ensure directory exists + os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) + + if os.path.exists(self.peers_file): + with open(self.peers_file, 'r') as f: + try: + self.peers = json.load(f) + self.logger.info(f"Loaded {len(self.peers)} peers from file") + except Exception as e: + self.logger.error(f"Error loading peers: {e}") + self.peers = [] + else: + self.peers = [] + self.logger.info("No peers file found, starting with empty registry") + except Exception as e: + self.logger.error(f"Error in _load_peers: {e}") + self.peers = [] + + def _save_peers(self): + """Save peers to file""" + try: + # Ensure directory exists + os.makedirs(os.path.dirname(self.peers_file), exist_ok=True) + + # Write to a temp file with restrictive perms, then atomically replace. + # peers.json contains WireGuard private keys — must never be world-readable. + tmp_path = self.peers_file + '.tmp' + # Open with O_CREAT|O_WRONLY|O_TRUNC and mode 0o600 so the file is + # created with restrictive permissions from the very first byte. + fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + try: + with os.fdopen(fd, 'w') as f: + json.dump(self.peers, f, indent=2) + except Exception: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + # Ensure perms are 0o600 even if umask or prior file affected them. + os.chmod(tmp_path, 0o600) + os.replace(tmp_path, self.peers_file) + # Belt-and-braces: also chmod the destination in case it pre-existed + # with looser perms on a filesystem that preserves the destination's mode. + os.chmod(self.peers_file, 0o600) + + self.logger.info(f"Saved {len(self.peers)} peers to file") + except Exception as e: + self.logger.error(f"Error saving peers: {e}") + + def list_peers(self) -> List[Dict[str, Any]]: + """List all peers""" + with self.lock: + return list(self.peers) + + def get_peer(self, name: str) -> Optional[Dict[str, Any]]: + """Get a specific peer by name""" + with self.lock: + for peer in self.peers: + if peer.get('peer') == name: + return peer + return None + + def add_peer(self, peer_info: Dict[str, Any]) -> bool: + """Add a new peer""" + try: + with self.lock: + if self.get_peer(peer_info.get('peer')): + self.logger.warning(f"Peer {peer_info.get('peer')} already exists") + return False + + # Add timestamp + peer_info['created_at'] = datetime.utcnow().isoformat() + peer_info['active'] = peer_info.get('active', True) + + self.peers.append(peer_info) + self._save_peers() + + self.logger.info(f"Added peer: {peer_info.get('peer')}") + return True + except Exception as e: + self.logger.error(f"Error adding peer: {e}") + return False + + def remove_peer(self, name: str) -> bool: + """Remove a peer""" + try: + with self.lock: + before = len(self.peers) + self.peers = [p for p in self.peers if p.get('peer') != name] + self._save_peers() + + removed = len(self.peers) < before + if removed: + self.logger.info(f"Removed peer: {name}") + else: + self.logger.warning(f"Peer {name} not found for removal") + + return removed + except Exception as e: + self.logger.error(f"Error removing peer {name}: {e}") + return False + + def update_peer(self, name: str, fields: Dict[str, Any]) -> bool: + """Update arbitrary fields on a peer.""" + try: + with self.lock: + for peer in self.peers: + if peer.get('peer') == name: + peer.update(fields) + peer['updated_at'] = datetime.utcnow().isoformat() + self._save_peers() + self.logger.info(f"Updated peer {name}: {list(fields.keys())}") + return True + self.logger.warning(f"Peer {name} not found for update") + return False + except Exception as e: + self.logger.error(f"Error updating peer {name}: {e}") + return False + + def clear_reinstall_flag(self, name: str) -> bool: + """Clear the config_needs_reinstall flag after user downloads new config.""" + return self.update_peer(name, {'config_needs_reinstall': False}) + + def update_peer_ip(self, name: str, new_ip: str) -> bool: + """Update peer IP address""" + try: + with self.lock: + for peer in self.peers: + if peer.get('peer') == name: + old_ip = peer.get('ip') + peer['ip'] = new_ip + peer['updated_at'] = datetime.utcnow().isoformat() + self._save_peers() + + self.logger.info(f"Updated peer {name} IP from {old_ip} to {new_ip}") + return True + + self.logger.warning(f"Peer {name} not found for IP update") + return False + except Exception as e: + self.logger.error(f"Error updating peer {name} IP: {e}") + return False + + def get_peer_stats(self) -> Dict[str, Any]: + """Get peer registry statistics""" + try: + with self.lock: + active_peers = [p for p in self.peers if p.get('active', True)] + inactive_peers = [p for p in self.peers if not p.get('active', True)] + + # Count peers by IP range + ip_ranges = {} + for peer in self.peers: + ip = peer.get('ip', '') + if ip: + range_key = '.'.join(ip.split('.')[:3]) + '.0/24' + ip_ranges[range_key] = ip_ranges.get(range_key, 0) + 1 + + return { + 'total_peers': len(self.peers), + 'active_peers': len(active_peers), + 'inactive_peers': len(inactive_peers), + 'ip_ranges': ip_ranges, + 'timestamp': datetime.utcnow().isoformat() + } + except Exception as e: + self.logger.error(f"Error getting peer stats: {e}") + return { + 'total_peers': 0, + 'active_peers': 0, + 'inactive_peers': 0, + 'ip_ranges': {}, + 'error': str(e), + 'timestamp': datetime.utcnow().isoformat() } \ No newline at end of file diff --git a/api/routing_manager.py b/api/routing_manager.py index 024c151..fde7e44 100644 --- a/api/routing_manager.py +++ b/api/routing_manager.py @@ -224,6 +224,22 @@ class RoutingManager(BaseServiceManager): def add_exit_node(self, peer_name: str, peer_ip: str, allowed_domains: List[str] = None) -> bool: """Add exit node configuration""" + # Validation — peer_ip flows into `ip route add default via `; argv + # injection / shell-meta in name would reach iptables/ip via _apply_exit_node. + if not isinstance(peer_name, str) or not re.match(r'^[a-zA-Z0-9_.-]{1,64}$', peer_name): + logger.error(f"add_exit_node: invalid peer_name {peer_name!r}") + return {'success': False, 'error': f'invalid input: peer_name {peer_name!r}'} + try: + ipaddress.ip_address(peer_ip) + except (ValueError, TypeError): + logger.error(f"add_exit_node: invalid peer_ip {peer_ip!r}") + return {'success': False, 'error': f'invalid input: peer_ip {peer_ip!r}'} + if allowed_domains is not None: + if not isinstance(allowed_domains, list): + return {'success': False, 'error': 'invalid input: allowed_domains must be a list'} + for d in allowed_domains: + if not isinstance(d, str) or not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9.-]{0,252}$', d): + return {'success': False, 'error': f'invalid input: domain {d!r}'} try: rules = self._load_rules() @@ -251,6 +267,23 @@ class RoutingManager(BaseServiceManager): def add_bridge_route(self, source_peer: str, target_peer: str, allowed_networks: List[str]) -> bool: """Add bridge route between peers""" + # source_peer is a name label; target_peer flows into iptables `-d` so must be + # an IP/CIDR. allowed_networks flows into iptables `-s` so must all be CIDRs. + if not isinstance(source_peer, str) or not re.match(r'^[a-zA-Z0-9_.-]{1,64}$', source_peer): + logger.error(f"add_bridge_route: invalid source_peer {source_peer!r}") + return {'success': False, 'error': f'invalid input: source_peer {source_peer!r}'} + try: + ipaddress.ip_network(target_peer, strict=False) + except (ValueError, TypeError): + logger.error(f"add_bridge_route: invalid target_peer {target_peer!r}") + return {'success': False, 'error': f'invalid input: target_peer must be IP/CIDR, got {target_peer!r}'} + if not isinstance(allowed_networks, list) or not allowed_networks: + return {'success': False, 'error': 'invalid input: allowed_networks must be a non-empty list'} + for n in allowed_networks: + try: + ipaddress.ip_network(n, strict=False) + except (ValueError, TypeError): + return {'success': False, 'error': f'invalid input: network {n!r}'} try: rules = self._load_rules() @@ -279,6 +312,22 @@ class RoutingManager(BaseServiceManager): def add_split_route(self, network: str, exit_peer: str, fallback_peer: str = None) -> bool: """Add split routing rule""" + # network flows into `ip route add `; exit_peer flows into `via `. + try: + ipaddress.ip_network(network, strict=False) + except (ValueError, TypeError): + logger.error(f"add_split_route: invalid network {network!r}") + return {'success': False, 'error': f'invalid input: network {network!r}'} + try: + ipaddress.ip_address(exit_peer) + except (ValueError, TypeError): + logger.error(f"add_split_route: invalid exit_peer {exit_peer!r}") + return {'success': False, 'error': f'invalid input: exit_peer must be an IP, got {exit_peer!r}'} + if fallback_peer is not None: + try: + ipaddress.ip_address(fallback_peer) + except (ValueError, TypeError): + return {'success': False, 'error': f'invalid input: fallback_peer must be an IP, got {fallback_peer!r}'} try: rules = self._load_rules() diff --git a/api/vault_manager.py b/api/vault_manager.py index 104b94b..58621a0 100644 --- a/api/vault_manager.py +++ b/api/vault_manager.py @@ -162,10 +162,26 @@ class VaultManager(BaseServiceManager): if self.fernet_key_file.exists(): with open(self.fernet_key_file, "rb") as f: self.fernet_key = f.read() + # SECURITY: ensure key file is owner-only readable on every load + # in case it was created with looser perms by an older version. + try: + os.chmod(str(self.fernet_key_file), 0o600) + except OSError: + pass else: self.fernet_key = Fernet.generate_key() - with open(self.fernet_key_file, "wb") as f: + # SECURITY: create the key file with 0o600 from the first byte + # so the secret is never world-readable, even momentarily. + fd = os.open( + str(self.fernet_key_file), + os.O_WRONLY | os.O_CREAT | os.O_TRUNC, + 0o600, + ) + with os.fdopen(fd, "wb") as f: f.write(self.fernet_key) + # Belt-and-braces chmod in case umask or a pre-existing file + # left wider permissions in place. + os.chmod(str(self.fernet_key_file), 0o600) self.fernet = Fernet(self.fernet_key) except (PermissionError, OSError): self.fernet_key = Fernet.generate_key() diff --git a/api/wireguard_manager.py b/api/wireguard_manager.py index 8e32666..473c6aa 100644 --- a/api/wireguard_manager.py +++ b/api/wireguard_manager.py @@ -459,12 +459,38 @@ class WireGuardManager(BaseServiceManager): Unlike add_peer(), allows a subnet CIDR as AllowedIPs (whole remote VPN range). The endpoint is expected to already include the port (e.g. '1.2.3.4:51820'). """ - import ipaddress + import ipaddress, re as _re + # Validate public_key strictly — empty/garbled keys later cause remove_peer("") + # to wipe ALL peer blocks via substring match. + if not isinstance(public_key, str) or not _re.match(r'^[A-Za-z0-9+/]{43}=$', public_key.strip()): + logger.error(f'add_cell_peer: invalid public_key') + return False + # Validate name — reject newlines/brackets that could inject config blocks + if not isinstance(name, str) or not _re.match(r'^[A-Za-z0-9_. -]{1,64}$', name): + logger.error(f'add_cell_peer: invalid name {name!r}') + return False + # Validate endpoint as host:port — reject newlines and out-of-range ports + if endpoint: + if not isinstance(endpoint, str) or not _re.match(r'^[A-Za-z0-9._-]+:\d{1,5}$', endpoint): + logger.error(f'add_cell_peer: invalid endpoint {endpoint!r}') + return False + try: + _port = int(endpoint.rsplit(':', 1)[1]) + if not (1 <= _port <= 65535): + logger.error(f'add_cell_peer: endpoint port out of range: {endpoint!r}') + return False + except (ValueError, IndexError): + logger.error(f'add_cell_peer: invalid endpoint port: {endpoint!r}') + return False try: ipaddress.ip_network(vpn_subnet, strict=False) except ValueError as e: logger.error(f'add_cell_peer: invalid vpn_subnet {vpn_subnet!r}: {e}') return False + # Reject any whitespace/newlines in vpn_subnet that ip_network() may have tolerated + if any(c.isspace() for c in vpn_subnet): + logger.error(f'add_cell_peer: vpn_subnet contains whitespace: {vpn_subnet!r}') + return False try: content = self._read_config() peer_block = ( @@ -531,6 +557,16 @@ class WireGuardManager(BaseServiceManager): def update_peer_ip(self, public_key: str, new_ip: str) -> bool: """Update AllowedIPs for the peer with the given public key.""" + import ipaddress + # Reject whitespace/newlines that ip_network() may tolerate but would inject config + if not isinstance(new_ip, str) or any(c.isspace() for c in new_ip): + logger.error(f'update_peer_ip: invalid new_ip {new_ip!r}') + return False + try: + ipaddress.ip_network(new_ip, strict=False) + except ValueError as e: + logger.error(f'update_peer_ip: invalid new_ip {new_ip!r}: {e}') + return False content = self._read_config() if f'PublicKey = {public_key}' not in content: return False @@ -737,6 +773,25 @@ class WireGuardManager(BaseServiceManager): status = self.get_status() running = status.get('running', False) return {'success': running, 'reachable': running, 'status': status.get('status')} + # Validate target_ip — reject argv injection (any string starting with '-' would + # be parsed by ping as a flag) and any non-IP input. + import ipaddress + if not isinstance(peer_ip, str) or peer_ip.startswith('-'): + return { + 'peer_ip': peer_ip, + 'ping_success': False, + 'ping_output': '', + 'ping_error': 'invalid peer_ip', + } + try: + ipaddress.ip_address(peer_ip) + except ValueError: + return { + 'peer_ip': peer_ip, + 'ping_success': False, + 'ping_output': '', + 'ping_error': 'invalid peer_ip', + } try: result = subprocess.run( ['ping', '-c', '1', '-W', '2', peer_ip], diff --git a/config/api/.gitkeep b/config/api/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/api/calendar.json b/config/api/calendar.json new file mode 100644 index 0000000..81fbc2e --- /dev/null +++ b/config/api/calendar.json @@ -0,0 +1,3 @@ +{ + "port": 5233 +} \ No newline at end of file diff --git a/config/api/cell_config.json b/config/api/cell_config.json index 4eb11e6..1a4a6a6 100644 --- a/config/api/cell_config.json +++ b/config/api/cell_config.json @@ -1,7 +1,7 @@ { "_identity": { "cell_name": "pic0", - "domain": "lan", + "domain": "dec", "ip_range": "172.20.0.0/16", "wireguard_port": 51820 }, diff --git a/config/caddy/Caddyfile b/config/caddy/Caddyfile index f385c8a..22cbd0b 100644 --- a/config/caddy/Caddyfile +++ b/config/caddy/Caddyfile @@ -3,7 +3,7 @@ } # Main cell domain — no service-IP restriction needed -http://mycell.cell, http://172.20.0.2:80 { +http://pic0.dec, http://172.20.0.2:80 { handle /api/* { reverse_proxy cell-api:3000 } @@ -22,26 +22,30 @@ http://mycell.cell, http://172.20.0.2:80 { } # Per-service virtual IPs — each gets its own IP so iptables can target them -http://calendar.cell, http://172.20.0.21:80 { +http://calendar.dec, http://172.20.0.21:80 { reverse_proxy cell-radicale:5232 } -http://files.cell, http://172.20.0.22:80 { +http://files.dec, http://172.20.0.22:80 { reverse_proxy cell-filegator:8080 } -http://mail.cell, http://webmail.cell, http://172.20.0.23:80 { +http://mail.dec, http://webmail.dec, http://172.20.0.23:80 { reverse_proxy cell-rainloop:8888 } -http://webdav.cell, http://172.20.0.24:80 { +http://webdav.dec, http://172.20.0.24:80 { reverse_proxy cell-webdav:80 } -http://api.cell { +http://api.dec { reverse_proxy cell-api:3000 } +http://webui.dec { + reverse_proxy cell-webui:80 +} + # Catch-all for direct IP / localhost :80 { handle /api/* { diff --git a/config/dhcp/.gitkeep b/config/dhcp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/dns/.gitkeep b/config/dns/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/dns/Corefile b/config/dns/Corefile index 92f151f..74bf426 100644 --- a/config/dns/Corefile +++ b/config/dns/Corefile @@ -5,8 +5,8 @@ health } -lan { - file /data/lan.zone +dec { + file /data/dec.zone log } diff --git a/config/ntp/.gitkeep b/config/ntp/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/webdav/.gitkeep b/config/webdav/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml index fe1aa15..5706bb1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -199,6 +199,7 @@ services: - ./data/api:/app/data - ./data/dns:/app/data/dns - ./config/api:/app/config + - ./config/caddy:/app/config-caddy - ./config/wireguard:/app/config/wireguard - ./config/dns:/app/config/dns - ./data/logs:/app/api/data/logs diff --git a/tests/e2e/wg/conftest.py b/tests/e2e/wg/conftest.py index 359e01d..ac1f80b 100644 --- a/tests/e2e/wg/conftest.py +++ b/tests/e2e/wg/conftest.py @@ -1,10 +1,16 @@ import os +import shutil import pytest import tempfile import secrets from helpers.wg_runner import WGInterface, build_wg_config, cleanup_stale_e2e_interfaces +def pytest_configure(config): + if not shutil.which('wg-quick'): + pytest.skip('wg-quick not found — skipping WireGuard E2E tests', allow_module_level=True) + + @pytest.fixture(scope='session', autouse=True) def cleanup_stale_wg_interfaces(): cleanup_stale_e2e_interfaces() diff --git a/tests/e2e/wg/test_caddy_routing.py b/tests/e2e/wg/test_caddy_routing.py new file mode 100644 index 0000000..e4128ba --- /dev/null +++ b/tests/e2e/wg/test_caddy_routing.py @@ -0,0 +1,275 @@ +""" +WireGuard E2E: Caddy per-domain routing correctness. + +Scenarios covered: + 35. api. proxies to the API (returns JSON), not the WebUI + 36. calendar. via VIP proxies to Radicale, not the WebUI + 37. files. via VIP proxies to Filegator, not the WebUI + 38. mail. via VIP proxies to Rainloop, not the WebUI + 39. webdav. via VIP proxies to the WebDAV service, not the WebUI + 40. Direct VIP requests (by IP) go to the correct service + 41. Catch-all :80 serves WebUI for unknown hosts but routes /api/* to API + +The WebUI serves a React app — its HTML starts with ''. +Any service domain that returns that string is incorrectly falling through +to the catch-all :80 block instead of being routed by its Host header. + +These tests require a live PIC stack with WireGuard and are marked `wg`. +They run via `make test-e2e-wg` or `pytest tests/e2e/wg/ -m wg`. +""" + +import subprocess +import pytest + +pytestmark = pytest.mark.wg + +_WEBUI_MARKER = '' + + +def _config(admin_client) -> dict: + r = admin_client.get('/api/config') + return r.json() if r.status_code == 200 else {} + + +def _domain(admin_client) -> str: + return _config(admin_client).get('domain') or 'lan' + + +def _dns_ip(admin_client) -> str: + cfg = _config(admin_client) + return cfg.get('service_ips', {}).get('dns') or '172.20.0.3' + + +def _curl_host(ip: str, host: str, path: str = '/', timeout: int = 8) -> tuple[int, str]: + """ + Make an HTTP request to `ip` with the given Host header. + Returns (http_code, body_snippet). + """ + result = subprocess.run( + ['curl', '-s', '--connect-timeout', '5', + '-H', f'Host: {host}', + '-w', '\n__HTTP_CODE__:%{http_code}', + f'http://{ip}{path}'], + capture_output=True, text=True, timeout=timeout, + ) + output = result.stdout + body = '' + code = 0 + if '__HTTP_CODE__:' in output: + parts = output.rsplit('__HTTP_CODE__:', 1) + body = parts[0].lower() + try: + code = int(parts[1].strip()) + except ValueError: + pass + return code, body + + +def _curl_domain(host: str, path: str = '/', dns_ip: str = '', timeout: int = 8) -> tuple[int, str]: + """Make an HTTP request using curl's --dns-servers to resolve via CoreDNS.""" + cmd = ['curl', '-s', '--connect-timeout', '5', + '-w', '\n__HTTP_CODE__:%{http_code}', + f'http://{host}{path}'] + if dns_ip: + cmd = ['curl', '-s', '--connect-timeout', '5', + '--dns-servers', dns_ip, + '-w', '\n__HTTP_CODE__:%{http_code}', + f'http://{host}{path}'] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + output = result.stdout + body = '' + code = 0 + if '__HTTP_CODE__:' in output: + parts = output.rsplit('__HTTP_CODE__:', 1) + body = parts[0].lower() + try: + code = int(parts[1].strip()) + except ValueError: + pass + return code, body + + +# ── Scenario 35: api. routes to API ─────────────────────────────────── + +def test_api_domain_returns_json_not_webui(connected_peer, admin_client): + """api./api/status must return JSON, not the React WebUI HTML.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'api.{dom}', '/api/status', dns_ip) + assert code not in (0, 000), f"curl to api.{dom}/api/status failed (code {code})" + assert _WEBUI_MARKER not in body, ( + f"api.{dom}/api/status returned WebUI HTML — " + "Caddy is not routing api. to the API; " + "check that the http://api. block exists in the Caddyfile " + "and uses the configured domain (not a stale .cell or .dev TLD)" + ) + assert '{' in body or '"' in body, ( + f"api.{dom}/api/status did not return JSON (body: {body[:100]!r})" + ) + + +def test_api_vip_host_header_routes_to_api(connected_peer, admin_client): + """Caddy routes api. by Host header even when accessed via the Caddy VIP.""" + dom = _domain(admin_client) + code, body = _curl_host('172.20.0.2', f'api.{dom}', '/api/status') + assert _WEBUI_MARKER not in body, ( + f"Host: api.{dom} via 172.20.0.2 returned WebUI HTML — " + "Caddy http://api. block is missing or uses wrong TLD" + ) + + +# ── Scenario 36: calendar. routes to Radicale ──────────────────────── + +def test_calendar_vip_does_not_serve_webui(connected_peer, admin_client): + """calendar. (VIP 172.20.0.21) must proxy to Radicale, not the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'calendar.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to calendar.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"calendar.{dom} returned WebUI HTML — " + "Caddy is not routing calendar. to Radicale. " + "This happens when Caddy has old (e.g. .cell) domain blocks and all " + "traffic falls through to the catch-all :80 block." + ) + + +def test_calendar_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.21 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.21', 'calendar.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.21 (calendar VIP) returned WebUI HTML — " + "Caddy http://calendar., http://172.20.0.21:80 block is missing or stale" + ) + + +# ── Scenario 37: files. routes to Filegator ────────────────────────── + +def test_files_vip_does_not_serve_webui(connected_peer, admin_client): + """files. (VIP 172.20.0.22) must proxy to Filegator, not the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'files.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to files.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"files.{dom} returned WebUI HTML — " + "Caddy is not routing files. to Filegator. " + "Check the http://files., http://172.20.0.22:80 Caddyfile block." + ) + + +def test_files_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.22 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.22', 'files.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.22 (files VIP) returned WebUI HTML — " + "Caddy http://files., http://172.20.0.22:80 block is missing or stale" + ) + + +# ── Scenario 38: mail. routes to Rainloop ──────────────────────────── + +def test_mail_vip_does_not_serve_webui(connected_peer, admin_client): + """mail. (VIP 172.20.0.23) must proxy to Rainloop, not the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'mail.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to mail.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"mail.{dom} returned WebUI HTML — " + "Caddy is not routing mail. to Rainloop." + ) + + +def test_webmail_vip_does_not_serve_webui(connected_peer, admin_client): + """webmail. (alias, same VIP 172.20.0.23) must NOT return the WebUI.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'webmail.{dom}', '/', dns_ip) + assert _WEBUI_MARKER not in body, ( + f"webmail.{dom} returned WebUI HTML — " + "Caddy http://webmail. block is missing or stale" + ) + + +def test_mail_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.23 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.23', 'mail.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.23 (mail VIP) returned WebUI HTML — " + "Caddy http://mail., http://172.20.0.23:80 block is missing or stale" + ) + + +# ── Scenario 39: webdav. routes to WebDAV ──────────────────────────── + +def test_webdav_vip_does_not_serve_webui(connected_peer, admin_client): + """webdav. (VIP 172.20.0.24) must proxy to the WebDAV service.""" + dom = _domain(admin_client) + dns_ip = _dns_ip(admin_client) + code, body = _curl_domain(f'webdav.{dom}', '/', dns_ip) + assert code not in (0,), f"curl to webdav.{dom} failed completely" + assert _WEBUI_MARKER not in body, ( + f"webdav.{dom} returned WebUI HTML — " + "Caddy is not routing webdav. to the WebDAV service." + ) + + +def test_webdav_vip_ip_does_not_serve_webui(connected_peer): + """Direct request to VIP 172.20.0.24 must NOT return the WebUI.""" + code, body = _curl_host('172.20.0.24', 'webdav.lan') + assert _WEBUI_MARKER not in body, ( + "172.20.0.24 (webdav VIP) returned WebUI HTML — " + "Caddy http://webdav., http://172.20.0.24:80 block is missing or stale" + ) + + +# ── Scenario 40: VIP IPs without Host header ───────────────────────────────── + +@pytest.mark.parametrize('vip,expected_not', [ + ('172.20.0.21', _WEBUI_MARKER), + ('172.20.0.22', _WEBUI_MARKER), + ('172.20.0.23', _WEBUI_MARKER), + ('172.20.0.24', _WEBUI_MARKER), +]) +def test_vip_direct_access_not_webui(connected_peer, vip, expected_not): + """Each service VIP accessed directly (no special Host) must not return WebUI.""" + code, body = _curl_host(vip, vip) + assert expected_not not in body, ( + f"VIP {vip} returned WebUI HTML — " + "Caddy catch-all :80 is taking over; the per-VIP blocks must listen on port 80" + ) + + +# ── Scenario 41: Catch-all :80 routes API path correctly ───────────────────── + +def test_catchall_api_path_returns_json(connected_peer): + """The catch-all :80 block must route /api/* to the API (not WebUI).""" + code, body = _curl_host('172.20.0.2', 'localhost', '/api/status') + assert _WEBUI_MARKER not in body, ( + "Catch-all :80 returned WebUI HTML for /api/status — " + "the `handle /api/*` directive in the :80 block is missing or wrong" + ) + assert '{' in body or '"' in body, ( + f"/api/status via catch-all did not return JSON (body: {body[:100]!r})" + ) + + +def test_catchall_root_serves_webui(connected_peer): + """The catch-all :80 block serves the WebUI for the root path.""" + code, body = _curl_host('172.20.0.2', 'localhost', '/') + assert _WEBUI_MARKER in body, ( + "Catch-all :80 / did not return WebUI HTML — " + "something is broken with the catch-all :80 block" + ) + + +# ── Scenario extra: stale TLD detection ────────────────────────────────────── + +def test_caddy_does_not_route_cell_tld(connected_peer): + """Caddy must NOT have active routing for .cell domains — they are from old config.""" + code, body = _curl_host('172.20.0.2', 'calendar.cell', '/') + assert _WEBUI_MARKER in body or code in (0, 404, 502, 503), ( + "Caddy is still routing calendar.cell — stale .cell blocks remain in config. " + "Check that write_caddyfile() is writing to the correct path that Caddy reads." + ) diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py index 9ea6c9c..6e7ee71 100644 --- a/tests/test_api_endpoints.py +++ b/tests/test_api_endpoints.py @@ -366,8 +366,8 @@ class TestAPIEndpoints(unittest.TestCase): def test_email_endpoints(self, mock_email): # Ensure all relevant mock methods return JSON-serializable values mock_email.get_users.return_value = [{'username': 'user1', 'domain': 'cell', 'email': 'user1@cell'}] - mock_email.create_user.return_value = True - mock_email.delete_user.return_value = True + mock_email.create_email_user.return_value = True + mock_email.delete_email_user.return_value = True mock_email.get_status.return_value = {'postfix_running': True, 'dovecot_running': True, 'total_users': 1, 'total_size_bytes': 0, 'total_size_mb': 0.0, 'users': [{'username': 'user1', 'domain': 'cell', 'email': 'user1@cell'}]} mock_email.test_connectivity.return_value = {'smtp': {'success': True, 'message': 'SMTP server responding'}} mock_email.send_email.return_value = True @@ -383,17 +383,17 @@ class TestAPIEndpoints(unittest.TestCase): # /api/email/users (POST) response = self.client.post('/api/email/users', data=json.dumps({'username': 'user1', 'domain': 'cell', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 200) - mock_email.create_user.side_effect = Exception('fail') + mock_email.create_email_user.side_effect = Exception('fail') response = self.client.post('/api/email/users', data=json.dumps({'username': 'user1', 'domain': 'cell', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 500) - mock_email.create_user.side_effect = None + mock_email.create_email_user.side_effect = None # /api/email/users/ (DELETE) response = self.client.delete('/api/email/users/user1') self.assertEqual(response.status_code, 200) - mock_email.delete_user.side_effect = Exception('fail') + mock_email.delete_email_user.side_effect = Exception('fail') response = self.client.delete('/api/email/users/user1') self.assertEqual(response.status_code, 500) - mock_email.delete_user.side_effect = None + mock_email.delete_email_user.side_effect = None # /api/email/status (GET) response = self.client.get('/api/email/status') self.assertEqual(response.status_code, 200) @@ -427,8 +427,8 @@ class TestAPIEndpoints(unittest.TestCase): def test_calendar_endpoints(self, mock_calendar): # Mock return values for all relevant calendar_manager methods mock_calendar.get_users.return_value = [{'username': 'user1', 'collections': {'calendars': ['cal1'], 'contacts': ['c1']}}] - mock_calendar.create_user.return_value = True - mock_calendar.delete_user.return_value = True + mock_calendar.create_calendar_user.return_value = True + mock_calendar.delete_calendar_user.return_value = True mock_calendar.create_calendar.return_value = {'calendar': 'cal1'} mock_calendar.add_event.return_value = {'event': 'event1'} mock_calendar.get_events.return_value = [{'event': 'event1'}] @@ -445,17 +445,17 @@ class TestAPIEndpoints(unittest.TestCase): # /api/calendar/users (POST) response = self.client.post('/api/calendar/users', data=json.dumps({'username': 'user1', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 200) - mock_calendar.create_user.side_effect = Exception('fail') + mock_calendar.create_calendar_user.side_effect = Exception('fail') response = self.client.post('/api/calendar/users', data=json.dumps({'username': 'user1', 'password': 'pw'}), content_type='application/json') self.assertEqual(response.status_code, 500) - mock_calendar.create_user.side_effect = None + mock_calendar.create_calendar_user.side_effect = None # /api/calendar/users/ (DELETE) response = self.client.delete('/api/calendar/users/user1') self.assertEqual(response.status_code, 200) - mock_calendar.delete_user.side_effect = Exception('fail') + mock_calendar.delete_calendar_user.side_effect = Exception('fail') response = self.client.delete('/api/calendar/users/user1') self.assertEqual(response.status_code, 500) - mock_calendar.delete_user.side_effect = None + mock_calendar.delete_calendar_user.side_effect = None # /api/calendar/calendars (POST) response = self.client.post('/api/calendar/calendars', data=json.dumps({'username': 'user1', 'calendar_name': 'cal1'}), content_type='application/json') self.assertEqual(response.status_code, 200) @@ -599,10 +599,10 @@ class TestAPIEndpoints(unittest.TestCase): self.assertEqual(response.status_code, 500) mock_routing.get_firewall_rules.side_effect = None # /api/routing/peers (POST) - response = self.client.post('/api/routing/peers', data=json.dumps({'peer': 'peer1', 'route': '10.0.0.2'}), content_type='application/json') + response = self.client.post('/api/routing/peers', data=json.dumps({'peer_name': 'peer1', 'peer_ip': '10.0.0.2'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_peer_route.side_effect = Exception('fail') - response = self.client.post('/api/routing/peers', data=json.dumps({'peer': 'peer1', 'route': '10.0.0.2'}), content_type='application/json') + response = self.client.post('/api/routing/peers', data=json.dumps({'peer_name': 'peer1', 'peer_ip': '10.0.0.2'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_peer_route.side_effect = None # /api/routing/peers (GET) @@ -620,24 +620,24 @@ class TestAPIEndpoints(unittest.TestCase): self.assertEqual(response.status_code, 500) mock_routing.remove_peer_route.side_effect = None # /api/routing/exit-nodes (POST) - response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'node': 'exit1'}), content_type='application/json') + response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'peer_name': 'exit1', 'peer_ip': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_exit_node.side_effect = Exception('fail') - response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'node': 'exit1'}), content_type='application/json') + response = self.client.post('/api/routing/exit-nodes', data=json.dumps({'peer_name': 'exit1', 'peer_ip': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_exit_node.side_effect = None # /api/routing/bridge (POST) - response = self.client.post('/api/routing/bridge', data=json.dumps({'bridge': 'br1'}), content_type='application/json') + response = self.client.post('/api/routing/bridge', data=json.dumps({'source_peer': 'peer1', 'target_peer': 'peer2'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_bridge_route.side_effect = Exception('fail') - response = self.client.post('/api/routing/bridge', data=json.dumps({'bridge': 'br1'}), content_type='application/json') + response = self.client.post('/api/routing/bridge', data=json.dumps({'source_peer': 'peer1', 'target_peer': 'peer2'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_bridge_route.side_effect = None # /api/routing/split (POST) - response = self.client.post('/api/routing/split', data=json.dumps({'split': 'sp1'}), content_type='application/json') + response = self.client.post('/api/routing/split', data=json.dumps({'network': '10.0.0.0/24', 'exit_peer': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 200) mock_routing.add_split_route.side_effect = Exception('fail') - response = self.client.post('/api/routing/split', data=json.dumps({'split': 'sp1'}), content_type='application/json') + response = self.client.post('/api/routing/split', data=json.dumps({'network': '10.0.0.0/24', 'exit_peer': '10.0.0.5'}), content_type='application/json') self.assertEqual(response.status_code, 500) mock_routing.add_split_route.side_effect = None # /api/routing/connectivity (POST) diff --git a/tests/test_app_misc.py b/tests/test_app_misc.py index e326921..8b6e8c5 100644 --- a/tests/test_app_misc.py +++ b/tests/test_app_misc.py @@ -113,8 +113,11 @@ class TestAppMisc(unittest.TestCase): self.assertFalse(app_module.is_local_request()) def test_is_local_request_private_ip(self): + # 192.168.x.x (LAN) is no longer trusted — only Docker bridge (172.16.0.0/12) + # and loopback are trusted. The API is bound to 127.0.0.1:3000 and only + # reachable via Caddy (172.20.x.x), so LAN IPs never reach it directly. with patch('app.request', new=self._req('192.168.1.5')): - self.assertTrue(app_module.is_local_request()) + self.assertFalse(app_module.is_local_request()) def test_is_local_request_xff_spoof_rejected(self): # Client sends X-Forwarded-For: 127.0.0.1 but actual IP is public @@ -123,8 +126,14 @@ class TestAppMisc(unittest.TestCase): self.assertFalse(app_module.is_local_request()) def test_is_local_request_xff_last_entry_local(self): - # Caddy appends the real client IP; last entry is local → allow + # 192.168.x.x is no longer in the trusted range — only Docker bridge + # (172.16.0.0/12) and loopback are trusted now. with patch('app.request', new=self._req('8.8.8.8', xff='8.8.8.8, 192.168.1.10')): + self.assertFalse(app_module.is_local_request()) + + def test_is_local_request_xff_docker_bridge(self): + # Docker bridge IPs (172.16.0.0/12) ARE trusted — Caddy uses this range + with patch('app.request', new=self._req('8.8.8.8', xff='8.8.8.8, 172.20.0.2')): self.assertTrue(app_module.is_local_request()) def test_is_local_request_xff_single_public_rejected(self): diff --git a/tests/test_calendar_endpoints.py b/tests/test_calendar_endpoints.py index 1f55537..b5ea38c 100644 --- a/tests/test_calendar_endpoints.py +++ b/tests/test_calendar_endpoints.py @@ -1 +1,379 @@ -# ... moved and adapted code from test_phase3_endpoints.py (calendar section) ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for calendar Flask endpoints in api/app.py. + +Covers: + GET /api/calendar/users + POST /api/calendar/users + DELETE /api/calendar/users/ + POST /api/calendar/calendars + POST /api/calendar/events + GET /api/calendar/events// + GET /api/calendar/status + GET /api/calendar/connectivity +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetCalendarUsers(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_get_users_returns_200_with_list(self, mock_cm): + mock_cm.get_users.return_value = [ + {'username': 'alice', 'email': 'alice@cell'}, + {'username': 'bob', 'email': 'bob@cell'}, + ] + r = self.client.get('/api/calendar/users') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.calendar_manager') + def test_get_users_returns_200_with_empty_list(self, mock_cm): + mock_cm.get_users.return_value = [] + r = self.client.get('/api/calendar/users') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.calendar_manager') + def test_get_users_returns_500_on_exception(self, mock_cm): + mock_cm.get_users.side_effect = Exception('radicale unreachable') + r = self.client.get('/api/calendar/users') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCreateCalendarUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_create_user_returns_200_on_valid_body(self, mock_cm): + mock_cm.create_calendar_user.return_value = True + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.calendar_manager') + def test_create_user_passes_credentials_to_manager(self, mock_cm): + mock_cm.create_calendar_user.return_value = True + self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + mock_cm.create_calendar_user.assert_called_once_with('alice', 'secret123') + + @patch('app.calendar_manager') + def test_create_user_returns_400_when_no_body(self, mock_cm): + r = self.client.post('/api/calendar/users') + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + mock_cm.create_calendar_user.assert_not_called() + + @patch('app.calendar_manager') + def test_create_user_returns_400_when_username_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.create_calendar_user.assert_not_called() + + @patch('app.calendar_manager') + def test_create_user_returns_400_when_password_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.create_calendar_user.assert_not_called() + + @patch('app.calendar_manager') + def test_create_user_returns_500_on_exception(self, mock_cm): + mock_cm.create_calendar_user.side_effect = Exception('htpasswd write failure') + r = self.client.post( + '/api/calendar/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteCalendarUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_delete_user_returns_200_on_success(self, mock_cm): + mock_cm.delete_calendar_user.return_value = True + r = self.client.delete('/api/calendar/users/alice') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('deleted', data) + + @patch('app.calendar_manager') + def test_delete_user_passes_username_to_manager(self, mock_cm): + mock_cm.delete_calendar_user.return_value = True + self.client.delete('/api/calendar/users/bob') + mock_cm.delete_calendar_user.assert_called_once_with('bob') + + @patch('app.calendar_manager') + def test_delete_user_returns_500_on_exception(self, mock_cm): + mock_cm.delete_calendar_user.side_effect = Exception('user not found') + r = self.client.delete('/api/calendar/users/alice') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCreateCalendar(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_create_calendar_returns_200_on_valid_body(self, mock_cm): + mock_cm.create_calendar.return_value = True + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice', 'name': 'Work'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.calendar_manager') + def test_create_calendar_accepts_calendar_name_alias(self, mock_cm): + mock_cm.create_calendar.return_value = True + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice', 'calendar_name': 'Personal'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + + @patch('app.calendar_manager') + def test_create_calendar_returns_400_when_no_body(self, mock_cm): + r = self.client.post('/api/calendar/calendars') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.create_calendar.assert_not_called() + + @patch('app.calendar_manager') + def test_create_calendar_returns_400_when_username_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'name': 'Work'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_create_calendar_returns_400_when_name_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_create_calendar_returns_500_on_exception(self, mock_cm): + mock_cm.create_calendar.side_effect = Exception('CalDAV error') + r = self.client.post( + '/api/calendar/calendars', + data=json.dumps({'username': 'alice', 'name': 'Work'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddCalendarEvent(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_add_event_returns_200_on_valid_body(self, mock_cm): + mock_cm.add_event.return_value = 'event-uid-123' + r = self.client.post( + '/api/calendar/events', + data=json.dumps({ + 'username': 'alice', + 'calendar_name': 'Work', + 'summary': 'Team Meeting', + 'dtstart': '20260427T100000Z', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.calendar_manager') + def test_add_event_returns_400_when_no_body(self, mock_cm): + r = self.client.post('/api/calendar/events') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_cm.add_event.assert_not_called() + + @patch('app.calendar_manager') + def test_add_event_returns_400_when_username_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/events', + data=json.dumps({'calendar_name': 'Work', 'summary': 'Meeting'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_add_event_returns_400_when_calendar_missing(self, mock_cm): + r = self.client.post( + '/api/calendar/events', + data=json.dumps({'username': 'alice', 'summary': 'Meeting'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.calendar_manager') + def test_add_event_returns_500_on_exception(self, mock_cm): + mock_cm.add_event.side_effect = Exception('iCalendar parse error') + r = self.client.post( + '/api/calendar/events', + data=json.dumps({ + 'username': 'alice', + 'calendar_name': 'Work', + 'summary': 'Meeting', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetCalendarEvents(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_get_events_returns_200_with_events(self, mock_cm): + mock_cm.get_events.return_value = [ + {'uid': 'abc', 'summary': 'Standup', 'dtstart': '20260427T090000Z'}, + ] + r = self.client.get('/api/calendar/events/alice/Work') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 1) + + @patch('app.calendar_manager') + def test_get_events_passes_username_and_calendar_to_manager(self, mock_cm): + mock_cm.get_events.return_value = [] + self.client.get('/api/calendar/events/bob/Personal') + mock_cm.get_events.assert_called_once() + args = mock_cm.get_events.call_args[0] + self.assertEqual(args[0], 'bob') + self.assertEqual(args[1], 'Personal') + + @patch('app.calendar_manager') + def test_get_events_returns_500_on_exception(self, mock_cm): + mock_cm.get_events.side_effect = Exception('calendar not found') + r = self.client.get('/api/calendar/events/alice/Work') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetCalendarStatus(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_get_status_returns_200_with_status_dict(self, mock_cm): + mock_cm.get_status.return_value = { + 'running': True, + 'port': 5232, + 'users_count': 3, + } + r = self.client.get('/api/calendar/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('running', data) + + @patch('app.calendar_manager') + def test_get_status_returns_500_on_exception(self, mock_cm): + mock_cm.get_status.side_effect = Exception('container not found') + r = self.client.get('/api/calendar/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCalendarConnectivity(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.calendar_manager') + def test_connectivity_returns_200_with_result(self, mock_cm): + mock_cm.test_connectivity.return_value = { + 'caldav': True, + 'carddav': True, + 'latency_ms': 8, + } + r = self.client.get('/api/calendar/connectivity') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('caldav', data) + + @patch('app.calendar_manager') + def test_connectivity_returns_500_on_exception(self, mock_cm): + mock_cm.test_connectivity.side_effect = Exception('connection refused') + r = self.client.get('/api/calendar/connectivity') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_cell_link_dns.py b/tests/test_cell_link_dns.py new file mode 100644 index 0000000..297d5d8 --- /dev/null +++ b/tests/test_cell_link_dns.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +""" +Tests for cell-to-cell DNS forwarding integration. + +Covers: + - generate_corefile() with cell_links entries + - apply_all_dns_rules() passing cell_links through to generate_corefile() + - Correct domain/dns_ip values in the emitted forwarding stanza + - Validation: invalid characters in domain are rejected by add_cell_dns_forward() +""" + +import sys +import os +import tempfile +import shutil +import unittest +from unittest.mock import patch, MagicMock, call +from pathlib import Path + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +import firewall_manager + + +# --------------------------------------------------------------------------- +# generate_corefile() with cell_links +# --------------------------------------------------------------------------- + +class TestGenerateCorefileOneLink(unittest.TestCase): + """generate_corefile() with a single cell link produces the right stanza.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.path = os.path.join(self.tmp, 'Corefile') + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _read(self): + return open(self.path).read() + + def test_forwarding_block_present(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('remote.cell {', content) + + def test_correct_dns_ip_in_forward_directive(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('forward . 10.5.0.1', content) + + def test_cache_directive_present_in_forwarding_block(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + # 'cache' must appear in the forwarding block (after the primary zone block) + idx_primary = content.index('remote.cell {') + self.assertIn('cache', content[idx_primary:]) + + def test_log_directive_present_in_forwarding_block(self): + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + idx_primary = content.index('remote.cell {') + self.assertIn('log', content[idx_primary:]) + + def test_forwarding_block_appears_after_primary_zone(self): + """The cell link stanza must appear after the primary zone block, not inside it.""" + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.5.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + # Primary zone ends with its closing brace; remote.cell block follows + idx_primary_zone = content.index('cell {') + idx_forward_block = content.index('remote.cell {') + self.assertGreater(idx_forward_block, idx_primary_zone) + + +class TestGenerateCorefileMultipleLinks(unittest.TestCase): + """generate_corefile() with multiple cell links produces one stanza each.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.path = os.path.join(self.tmp, 'Corefile') + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _read(self): + return open(self.path).read() + + def test_all_domains_present(self): + cell_links = [ + {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'}, + {'domain': 'gamma.cell', 'dns_ip': '10.3.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('alpha.cell {', content) + self.assertIn('beta.cell {', content) + self.assertIn('gamma.cell {', content) + + def test_all_dns_ips_present(self): + cell_links = [ + {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + self.assertIn('forward . 10.1.0.1', content) + self.assertIn('forward . 10.2.0.1', content) + + def test_stanza_count_matches_link_count(self): + """Each valid link contributes exactly one forwarding stanza.""" + cell_links = [ + {'domain': 'a.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'b.cell', 'dns_ip': '10.2.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._read() + # Count occurrences of 'forward .' — one for default, one per cell link + count = content.count('forward .') + self.assertEqual(count, 3) # 1 default + 2 cell links + + +# --------------------------------------------------------------------------- +# apply_all_dns_rules() passes cell_links through to generate_corefile() +# --------------------------------------------------------------------------- + +class TestApplyAllDnsRulesPassesCellLinks(unittest.TestCase): + """apply_all_dns_rules() must forward the cell_links argument to generate_corefile().""" + + def test_cell_links_forwarded(self): + cell_links = [{'domain': 'x.cell', 'dns_ip': '10.9.0.1'}] + with patch.object(firewall_manager, 'generate_corefile', return_value=True) as mock_gen, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True): + firewall_manager.apply_all_dns_rules( + peers=[], + corefile_path='/tmp/fake_Corefile', + domain='cell', + cell_links=cell_links, + ) + mock_gen.assert_called_once_with( + [], '/tmp/fake_Corefile', 'cell', cell_links + ) + + def test_cell_links_none_forwarded_as_none(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=True) as mock_gen, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True): + firewall_manager.apply_all_dns_rules( + peers=[], + corefile_path='/tmp/fake_Corefile', + domain='cell', + cell_links=None, + ) + mock_gen.assert_called_once_with([], '/tmp/fake_Corefile', 'cell', None) + + def test_reload_called_on_success(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=True), \ + patch.object(firewall_manager, 'reload_coredns', return_value=True) as mock_reload: + firewall_manager.apply_all_dns_rules([], '/tmp/f', cell_links=None) + mock_reload.assert_called_once() + + def test_reload_not_called_on_failure(self): + with patch.object(firewall_manager, 'generate_corefile', return_value=False), \ + patch.object(firewall_manager, 'reload_coredns') as mock_reload: + firewall_manager.apply_all_dns_rules([], '/tmp/f', cell_links=None) + mock_reload.assert_not_called() + + +# --------------------------------------------------------------------------- +# Domain validation in add_cell_dns_forward() (via network_manager) +# --------------------------------------------------------------------------- + +class TestAddCellDnsForwardValidation(unittest.TestCase): + """ + add_cell_dns_forward() must reject malformed domains/IPs without writing + the Corefile or calling apply_all_dns_rules(). + """ + + def _get_network_manager(self, tmp_dir): + """Construct a minimal NetworkManager with test directories.""" + # We import here so the test file doesn't hard-fail if network_manager + # has an import-time dependency that's unavailable in CI. + try: + from network_manager import NetworkManager + except ImportError as e: + self.skipTest(f'NetworkManager import failed: {e}') + os.makedirs(os.path.join(tmp_dir, 'dns'), exist_ok=True) + return NetworkManager(data_dir=tmp_dir, config_dir=tmp_dir) + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp) + + def test_invalid_dns_ip_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('valid.cell', 'not-an-ip') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_domain_with_newline_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('evil\ndomain', '10.1.0.1') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_domain_with_braces_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('evil{domain}', '10.1.0.1') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_domain_with_space_returns_warning(self): + nm = self._get_network_manager(self.tmp) + result = nm.add_cell_dns_forward('evil domain', '10.1.0.1') + self.assertTrue(result['warnings']) + self.assertFalse(result['restarted']) + + def test_valid_domain_and_ip_calls_apply_all_dns_rules(self): + """Valid inputs must call firewall_manager.apply_all_dns_rules().""" + nm = self._get_network_manager(self.tmp) + with patch.object(firewall_manager, 'apply_all_dns_rules', return_value=True) as mock_apply, \ + patch.object(firewall_manager, 'reload_coredns', return_value=True): + result = nm.add_cell_dns_forward('valid.cell', '10.1.0.1') + mock_apply.assert_called_once() + call_kwargs = mock_apply.call_args + # cell_links kwarg must include the new entry + cell_links_arg = call_kwargs[1].get('cell_links') or call_kwargs[0][3] + domains = [l['domain'] for l in cell_links_arg] + self.assertIn('valid.cell', domains) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_cells_endpoints.py b/tests/test_cells_endpoints.py new file mode 100644 index 0000000..6a2c351 --- /dev/null +++ b/tests/test_cells_endpoints.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +""" +Unit tests for cell management Flask endpoints in api/app.py. + +Covers: + GET /api/cells/invite — generate invite package + GET /api/cells — list connected cells + POST /api/cells — connect to a remote cell + DELETE /api/cells/ — disconnect from a cell + GET /api/cells//status — live status for a connected cell +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + +# Minimal set of required fields for POST /api/cells +_VALID_CELL_BODY = { + 'cell_name': 'remotecell', + 'public_key': 'abc123publickey==', + 'vpn_subnet': '10.1.0.0/24', + 'dns_ip': '10.1.0.1', + 'domain': 'remotecell.cell', +} + + +class TestGetCellInvite(unittest.TestCase): + """GET /api/cells/invite""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + @patch('app.config_manager') + def test_get_invite_returns_200_with_invite_dict(self, mock_cfg, mock_clm): + mock_cfg.configs = {'_identity': {'cell_name': 'mycell', 'domain': 'cell'}} + mock_clm.generate_invite.return_value = { + 'cell_name': 'mycell', + 'public_key': 'server_pub_key==', + 'vpn_subnet': '10.0.0.0/24', + 'dns_ip': '10.0.0.1', + 'domain': 'cell', + } + r = self.client.get('/api/cells/invite') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('cell_name', data) + self.assertIn('public_key', data) + + @patch('app.cell_link_manager') + @patch('app.config_manager') + def test_get_invite_passes_cell_name_and_domain(self, mock_cfg, mock_clm): + mock_cfg.configs = {'_identity': {'cell_name': 'myhome', 'domain': 'home'}} + mock_clm.generate_invite.return_value = {} + self.client.get('/api/cells/invite') + mock_clm.generate_invite.assert_called_once_with('myhome', 'home') + + @patch('app.cell_link_manager') + @patch('app.config_manager') + def test_get_invite_returns_500_on_exception(self, mock_cfg, mock_clm): + mock_cfg.configs = {'_identity': {}} + mock_clm.generate_invite.side_effect = Exception('WireGuard key unavailable') + r = self.client.get('/api/cells/invite') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestListCellConnections(unittest.TestCase): + """GET /api/cells""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_list_cells_returns_200_with_list(self, mock_clm): + mock_clm.list_connections.return_value = [ + {'cell_name': 'remotecell', 'domain': 'remotecell.cell', 'status': 'connected'}, + ] + r = self.client.get('/api/cells') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 1) + self.assertEqual(data[0]['cell_name'], 'remotecell') + + @patch('app.cell_link_manager') + def test_list_cells_returns_empty_list_when_none_connected(self, mock_clm): + mock_clm.list_connections.return_value = [] + r = self.client.get('/api/cells') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.cell_link_manager') + def test_list_cells_returns_500_on_exception(self, mock_clm): + mock_clm.list_connections.side_effect = Exception('storage error') + r = self.client.get('/api/cells') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddCellConnection(unittest.TestCase): + """POST /api/cells""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_add_cell_returns_201_on_success(self, mock_clm): + mock_clm.add_connection.return_value = {'cell_name': 'remotecell'} + r = self.client.post( + '/api/cells', + data=json.dumps(_VALID_CELL_BODY), + content_type='application/json', + ) + self.assertEqual(r.status_code, 201) + data = json.loads(r.data) + self.assertIn('message', data) + self.assertIn('link', data) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_no_body(self, mock_clm): + r = self.client.post('/api/cells') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_clm.add_connection.assert_not_called() + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_cell_name_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'cell_name'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_public_key_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'public_key'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_vpn_subnet_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'vpn_subnet'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_dns_ip_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'dns_ip'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_when_domain_missing(self, mock_clm): + body = {k: v for k, v in _VALID_CELL_BODY.items() if k != 'domain'} + r = self.client.post( + '/api/cells', + data=json.dumps(body), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_400_on_value_error_from_manager(self, mock_clm): + mock_clm.add_connection.side_effect = ValueError('cell already connected') + r = self.client.post( + '/api/cells', + data=json.dumps(_VALID_CELL_BODY), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_add_cell_returns_500_on_unexpected_exception(self, mock_clm): + mock_clm.add_connection.side_effect = Exception('WireGuard peer add failed') + r = self.client.post( + '/api/cells', + data=json.dumps(_VALID_CELL_BODY), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestRemoveCellConnection(unittest.TestCase): + """DELETE /api/cells/""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_remove_cell_returns_200_on_success(self, mock_clm): + mock_clm.remove_connection.return_value = None + r = self.client.delete('/api/cells/remotecell') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.cell_link_manager') + def test_remove_cell_passes_cell_name_to_manager(self, mock_clm): + mock_clm.remove_connection.return_value = None + self.client.delete('/api/cells/faraway') + mock_clm.remove_connection.assert_called_once_with('faraway') + + @patch('app.cell_link_manager') + def test_remove_cell_returns_404_on_value_error(self, mock_clm): + mock_clm.remove_connection.side_effect = ValueError('cell not found') + r = self.client.delete('/api/cells/nonexistent') + self.assertEqual(r.status_code, 404) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_remove_cell_returns_500_on_unexpected_exception(self, mock_clm): + mock_clm.remove_connection.side_effect = Exception('storage corruption') + r = self.client.delete('/api/cells/remotecell') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetCellConnectionStatus(unittest.TestCase): + """GET /api/cells//status""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.cell_link_manager') + def test_get_cell_status_returns_200_with_status_dict(self, mock_clm): + mock_clm.get_connection_status.return_value = { + 'cell_name': 'remotecell', + 'online': True, + 'last_handshake': '2026-04-27T09:00:00Z', + 'transfer_rx': 1024, + 'transfer_tx': 2048, + } + r = self.client.get('/api/cells/remotecell/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('online', data) + self.assertTrue(data['online']) + + @patch('app.cell_link_manager') + def test_get_cell_status_passes_cell_name(self, mock_clm): + mock_clm.get_connection_status.return_value = {} + self.client.get('/api/cells/faraway/status') + mock_clm.get_connection_status.assert_called_once_with('faraway') + + @patch('app.cell_link_manager') + def test_get_cell_status_returns_404_on_value_error(self, mock_clm): + mock_clm.get_connection_status.side_effect = ValueError('cell not found') + r = self.client.get('/api/cells/missing/status') + self.assertEqual(r.status_code, 404) + self.assertIn('error', json.loads(r.data)) + + @patch('app.cell_link_manager') + def test_get_cell_status_returns_500_on_unexpected_exception(self, mock_clm): + mock_clm.get_connection_status.side_effect = Exception('WireGuard query failed') + r = self.client.get('/api/cells/remotecell/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_email_endpoints.py b/tests/test_email_endpoints.py index 0ce01fe..57cbd9e 100644 --- a/tests/test_email_endpoints.py +++ b/tests/test_email_endpoints.py @@ -1 +1,212 @@ -# ... moved and adapted code from test_phase3_endpoints.py (email section) ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for email Flask endpoints in api/app.py. + +Covers: + GET /api/email/users + POST /api/email/users + DELETE /api/email/users/ + GET /api/email/status + GET /api/email/connectivity +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetEmailUsers(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_get_users_returns_200_with_list(self, mock_em): + mock_em.get_users.return_value = [ + {'username': 'alice@cell', 'domain': 'cell'}, + {'username': 'bob@cell', 'domain': 'cell'}, + ] + r = self.client.get('/api/email/users') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.email_manager') + def test_get_users_returns_empty_list_when_no_users(self, mock_em): + mock_em.get_users.return_value = [] + r = self.client.get('/api/email/users') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.email_manager') + def test_get_users_returns_500_on_exception(self, mock_em): + mock_em.get_users.side_effect = Exception('mailbox unreachable') + r = self.client.get('/api/email/users') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestCreateEmailUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_create_user_returns_200_on_success(self, mock_em): + mock_em.create_email_user.return_value = True + r = self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('created', data) + + @patch('app.email_manager') + def test_create_user_calls_manager_with_username_and_password(self, mock_em): + mock_em.create_email_user.return_value = True + self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + mock_em.create_email_user.assert_called_once() + args = mock_em.create_email_user.call_args[0] + self.assertEqual(args[0], 'alice') + self.assertEqual(args[2], 'secret123') + + @patch('app.email_manager') + def test_create_user_returns_400_when_username_missing(self, mock_em): + r = self.client.post( + '/api/email/users', + data=json.dumps({'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_em.create_email_user.assert_not_called() + + @patch('app.email_manager') + def test_create_user_returns_400_when_password_missing(self, mock_em): + r = self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_em.create_email_user.assert_not_called() + + @patch('app.email_manager') + def test_create_user_returns_400_when_no_body(self, mock_em): + r = self.client.post('/api/email/users') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.email_manager') + def test_create_user_returns_500_on_exception(self, mock_em): + mock_em.create_email_user.side_effect = Exception('SMTP config error') + r = self.client.post( + '/api/email/users', + data=json.dumps({'username': 'alice', 'password': 'secret123'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteEmailUser(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_delete_user_returns_200_on_success(self, mock_em): + mock_em.delete_email_user.return_value = True + r = self.client.delete('/api/email/users/alice') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('deleted', data) + + @patch('app.email_manager') + def test_delete_user_calls_manager_with_username(self, mock_em): + mock_em.delete_email_user.return_value = True + self.client.delete('/api/email/users/bob') + mock_em.delete_email_user.assert_called_once() + args = mock_em.delete_email_user.call_args[0] + self.assertEqual(args[0], 'bob') + + @patch('app.email_manager') + def test_delete_user_returns_500_on_exception(self, mock_em): + mock_em.delete_email_user.side_effect = Exception('LDAP error') + r = self.client.delete('/api/email/users/alice') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetEmailStatus(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_get_status_returns_200_with_status_dict(self, mock_em): + mock_em.get_status.return_value = { + 'running': True, + 'smtp_port': 25, + 'imap_port': 993, + } + r = self.client.get('/api/email/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('running', data) + + @patch('app.email_manager') + def test_get_status_returns_500_on_exception(self, mock_em): + mock_em.get_status.side_effect = Exception('container not found') + r = self.client.get('/api/email/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestEmailConnectivity(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.email_manager') + def test_connectivity_returns_200_with_result(self, mock_em): + mock_em.test_connectivity.return_value = { + 'smtp': True, + 'imap': True, + 'latency_ms': 12, + } + r = self.client.get('/api/email/connectivity') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('smtp', data) + + @patch('app.email_manager') + def test_connectivity_returns_500_on_exception(self, mock_em): + mock_em.test_connectivity.side_effect = Exception('timeout') + r = self.client.get('/api/email/connectivity') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_enforce_auth_configured.py b/tests/test_enforce_auth_configured.py new file mode 100644 index 0000000..075832c --- /dev/null +++ b/tests/test_enforce_auth_configured.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" +Tests for the enforce_auth before_request hook in api/app.py. + +The hook has two distinct behaviours depending on the auth store state: + - users file exists and is POPULATED → auth is enforced (unauthenticated → 401) + - users file exists but is EMPTY → 503 (auth not configured) + - users file does not exist / unreadable → bypass (pre-auth compat mode) + +These tests create real AuthManager instances pointing at tmp directories so +that list_users() and the file-readability check both behave exactly as they +do in production. +""" + +import os +import sys +import json +import pytest +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).parent.parent / 'api')) + + +@pytest.fixture +def flask_client(): + from app import app + app.config['TESTING'] = True + return app.test_client() + + +@pytest.fixture +def populated_auth_manager(tmp_path): + """AuthManager whose users file contains at least one admin account.""" + from auth_manager import AuthManager + data_dir = str(tmp_path / 'data') + config_dir = str(tmp_path / 'config') + os.makedirs(data_dir, exist_ok=True) + os.makedirs(config_dir, exist_ok=True) + mgr = AuthManager(data_dir=data_dir, config_dir=config_dir) + # Create an admin so list_users() is non-empty + ok = mgr.create_user('admin', 'AdminPass123!', 'admin') + assert ok, 'Could not seed admin user for test' + return mgr + + +@pytest.fixture +def empty_auth_manager(tmp_path): + """AuthManager whose users file exists and is readable but contains no users.""" + from auth_manager import AuthManager + data_dir = str(tmp_path / 'data') + config_dir = str(tmp_path / 'config') + os.makedirs(data_dir, exist_ok=True) + os.makedirs(config_dir, exist_ok=True) + mgr = AuthManager(data_dir=data_dir, config_dir=config_dir) + # The constructor creates the file with '[]' (empty list). We do NOT add + # any user, so list_users() returns [] but the file is readable. + assert mgr.list_users() == [], 'Expected empty user list' + return mgr + + +# ── populated store → auth enforced ────────────────────────────────────────── + +def test_populated_auth_manager_unauthenticated_request_gets_401( + flask_client, populated_auth_manager +): + """When the auth store has users, unauthenticated API requests must get 401.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/api/status') + assert r.status_code == 401 + data = json.loads(r.data) + assert 'error' in data + + +def test_populated_auth_manager_401_body_says_not_authenticated( + flask_client, populated_auth_manager +): + """The 401 body must clearly indicate the session is missing.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/api/peers') + assert r.status_code == 401 + data = json.loads(r.data) + assert 'Not authenticated' in data.get('error', '') + + +def test_populated_auth_manager_non_api_path_bypasses_auth( + flask_client, populated_auth_manager +): + """Non-API paths like /health must always be public.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/health') + assert r.status_code == 200 + + +def test_populated_auth_manager_auth_namespace_bypasses_auth( + flask_client, populated_auth_manager +): + """The /api/auth/* namespace must always be accessible without a session.""" + with patch('app.auth_manager', populated_auth_manager): + r = flask_client.get('/api/auth/me') + # /api/auth/me may return 401 from the route itself (no session), but it + # must NOT be blocked by enforce_auth; the enforce_auth hook must return None + # for /api/auth/* paths. The status must not be 503. + assert r.status_code != 503 + + +# ── empty store → 503 ──────────────────────────────────────────────────────── + +def test_empty_auth_manager_returns_503_for_api_requests( + flask_client, empty_auth_manager +): + """When the users file exists and is readable but empty, /api/* must get 503.""" + with patch('app.auth_manager', empty_auth_manager): + r = flask_client.get('/api/status') + assert r.status_code == 503 + data = json.loads(r.data) + assert 'error' in data + + +def test_empty_auth_manager_503_body_mentions_configuration( + flask_client, empty_auth_manager +): + """The 503 error body must mention that auth is not configured.""" + with patch('app.auth_manager', empty_auth_manager): + r = flask_client.get('/api/config') + assert r.status_code == 503 + data = json.loads(r.data) + error_text = data.get('error', '') + assert 'not configured' in error_text.lower() or 'Authentication' in error_text + + +def test_empty_auth_manager_non_api_path_bypasses_503( + flask_client, empty_auth_manager +): + """Even with an empty auth store, /health must remain accessible.""" + with patch('app.auth_manager', empty_auth_manager): + r = flask_client.get('/health') + assert r.status_code == 200 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/test_file_endpoints.py b/tests/test_file_endpoints.py index 15ba155..e9080f2 100644 --- a/tests/test_file_endpoints.py +++ b/tests/test_file_endpoints.py @@ -231,7 +231,7 @@ class TestFileCreateFolderEndpoint(unittest.TestCase): mock_fm.create_folder.return_value = True r = self.client.post( '/api/files/folders', - data=json.dumps({'username': 'alice', 'folder': 'Archive'}), + data=json.dumps({'username': 'alice', 'folder_path': 'Archive'}), content_type='application/json', ) self.assertEqual(r.status_code, 200) @@ -247,7 +247,7 @@ class TestFileCreateFolderEndpoint(unittest.TestCase): mock_fm.create_folder.side_effect = Exception('quota exceeded') r = self.client.post( '/api/files/folders', - data=json.dumps({'username': 'alice', 'folder': 'NewFolder'}), + data=json.dumps({'username': 'alice', 'folder_path': 'NewFolder'}), content_type='application/json', ) self.assertEqual(r.status_code, 500) diff --git a/tests/test_firewall_manager.py b/tests/test_firewall_manager.py index 024d152..cae337d 100644 --- a/tests/test_firewall_manager.py +++ b/tests/test_firewall_manager.py @@ -30,10 +30,12 @@ def _make_peer(ip, internet=True, services=None, peers=True): class TestPeerComment(unittest.TestCase): def test_dots_replaced_with_dashes(self): - self.assertEqual(firewall_manager._peer_comment('10.0.0.2'), 'pic-peer-10-0-0-2') + # Comment format now includes /32 suffix to prevent substring matches + # (e.g. pic-peer-10-0-0-1/32 is not a prefix of pic-peer-10-0-0-10/32) + self.assertEqual(firewall_manager._peer_comment('10.0.0.2'), 'pic-peer-10-0-0-2/32') def test_different_ip(self): - self.assertEqual(firewall_manager._peer_comment('192.168.1.100'), 'pic-peer-192-168-1-100') + self.assertEqual(firewall_manager._peer_comment('192.168.1.100'), 'pic-peer-192-168-1-100/32') # --------------------------------------------------------------------------- @@ -115,6 +117,87 @@ class TestGenerateCorefile(unittest.TestCase): self.assertFalse(result) +# --------------------------------------------------------------------------- +# generate_corefile with cell_links +# --------------------------------------------------------------------------- + +class TestGenerateCorefileWithCellLinks(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.path = os.path.join(self.tmp, 'Corefile') + + def tearDown(self): + shutil.rmtree(self.tmp) + + def _content(self): + return open(self.path).read() + + def test_cell_links_none_produces_no_forwarding_stanzas(self): + """Default (None) produces no extra forwarding blocks beyond the primary zone.""" + firewall_manager.generate_corefile([], self.path, cell_links=None) + content = self._content() + # The only 'forward' line should be the default internet forwarder + forward_lines = [l for l in content.splitlines() if 'forward' in l] + self.assertEqual(len(forward_lines), 1) + self.assertIn('8.8.8.8', forward_lines[0]) + + def test_cell_links_empty_list_produces_no_extra_stanzas(self): + """An empty cell_links list produces no extra forwarding blocks.""" + firewall_manager.generate_corefile([], self.path, cell_links=[]) + content = self._content() + forward_lines = [l for l in content.splitlines() if 'forward' in l] + self.assertEqual(len(forward_lines), 1) + self.assertIn('8.8.8.8', forward_lines[0]) + + def test_single_cell_link_produces_forwarding_block(self): + """One cell link produces one forwarding stanza with correct domain and dns_ip.""" + cell_links = [{'domain': 'remote.cell', 'dns_ip': '10.1.0.1'}] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + self.assertIn('remote.cell {', content) + self.assertIn('forward . 10.1.0.1', content) + self.assertIn('cache', content) + + def test_multiple_cell_links_produce_multiple_forwarding_blocks(self): + """Multiple cell links produce one stanza each.""" + cell_links = [ + {'domain': 'alpha.cell', 'dns_ip': '10.1.0.1'}, + {'domain': 'beta.cell', 'dns_ip': '10.2.0.1'}, + ] + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + self.assertIn('alpha.cell {', content) + self.assertIn('forward . 10.1.0.1', content) + self.assertIn('beta.cell {', content) + self.assertIn('forward . 10.2.0.1', content) + + def test_cell_links_do_not_overwrite_peer_acls(self): + """Cell link stanzas are appended; peer ACLs in the primary zone survive.""" + peers = [_make_peer('10.0.0.3', services=['calendar'])] + cell_links = [{'domain': 'other.cell', 'dns_ip': '10.99.0.1'}] + firewall_manager.generate_corefile(peers, self.path, cell_links=cell_links) + content = self._content() + self.assertIn('block net 10.0.0.3/32', content) + self.assertIn('other.cell {', content) + self.assertIn('forward . 10.99.0.1', content) + + def test_link_with_missing_domain_is_skipped(self): + """A cell_link entry with no domain key is silently skipped.""" + cell_links = [{'dns_ip': '10.1.0.1'}] # no 'domain' + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + # Only the default internet forwarder + forward_lines = [l for l in content.splitlines() if 'forward' in l] + self.assertEqual(len(forward_lines), 1) + + def test_link_with_missing_dns_ip_is_skipped(self): + """A cell_link entry with no dns_ip key is silently skipped.""" + cell_links = [{'domain': 'nope.cell'}] # no 'dns_ip' + firewall_manager.generate_corefile([], self.path, cell_links=cell_links) + content = self._content() + self.assertNotIn('nope.cell', content) + + # --------------------------------------------------------------------------- # apply_peer_rules — iptables call verification # --------------------------------------------------------------------------- @@ -227,8 +310,8 @@ class TestClearPeerRules(unittest.TestCase): '*filter\n' ':INPUT ACCEPT [0:0]\n' ':FORWARD ACCEPT [0:0]\n' - '-A FORWARD -s 10.0.0.2 -m comment --comment pic-peer-10-0-0-2 -j ACCEPT\n' - '-A FORWARD -s 10.0.0.3 -m comment --comment pic-peer-10-0-0-3 -j DROP\n' + '-A FORWARD -s 10.0.0.2 -m comment --comment "pic-peer-10-0-0-2/32" -j ACCEPT\n' + '-A FORWARD -s 10.0.0.3 -m comment --comment "pic-peer-10-0-0-3/32" -j DROP\n' 'COMMIT\n' ) restored = [] @@ -252,8 +335,8 @@ class TestClearPeerRules(unittest.TestCase): self.assertEqual(len(restored), 1) restored_content = restored[0] - self.assertNotIn('pic-peer-10-0-0-2', restored_content) - self.assertIn('pic-peer-10-0-0-3', restored_content) + self.assertNotIn('pic-peer-10-0-0-2/32', restored_content) + self.assertIn('pic-peer-10-0-0-3/32', restored_content) def test_no_op_when_no_matching_rules(self): save_output = '*filter\n:FORWARD ACCEPT [0:0]\nCOMMIT\n' diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py new file mode 100644 index 0000000..7b19242 --- /dev/null +++ b/tests/test_input_validation.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +""" +Tests for the security input validation on PUT /api/config. + +Validates that domain and cell_name fields reject injection characters +while allowing legitimate values (multi-label domains, hyphens, etc.). +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +def _put(client, payload): + return client.put( + '/api/config', + data=json.dumps(payload), + content_type='application/json', + ) + + +class TestDomainValidation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + def test_domain_with_newline_returns_400(self): + r = _put(self.client, {'domain': 'cell\nnewline'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_with_opening_brace_returns_400(self): + r = _put(self.client, {'domain': 'cell{injection}'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_with_semicolon_returns_400(self): + r = _put(self.client, {'domain': 'cell;rm -rf /'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_with_space_returns_400(self): + r = _put(self.client, {'domain': 'my cell'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_domain_multilabel_with_dot_returns_200(self): + # Multi-label names like 'cell.local' or 'home.lan' must be accepted. + r = _put(self.client, {'domain': 'cell.local'}) + # The endpoint may also return non-400 on 500 if downstream fails, + # but the validation itself must not reject dots. + self.assertNotEqual(r.status_code, 400) + + def test_domain_simple_word_returns_200(self): + r = _put(self.client, {'domain': 'myhome'}) + self.assertNotEqual(r.status_code, 400) + + def test_domain_with_hyphen_returns_200(self): + r = _put(self.client, {'domain': 'my-cell'}) + self.assertNotEqual(r.status_code, 400) + + def test_domain_with_at_sign_returns_400(self): + r = _put(self.client, {'domain': 'cell@evil.com'}) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + def test_domain_with_slash_returns_400(self): + r = _put(self.client, {'domain': 'cell/etc'}) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + +class TestCellNameValidation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + def test_cell_name_with_space_returns_400(self): + r = _put(self.client, {'cell_name': 'my cell'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_with_dot_returns_400(self): + # cell_name is a single hostname component — dots are not allowed + r = _put(self.client, {'cell_name': 'my.cell'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_with_newline_returns_400(self): + r = _put(self.client, {'cell_name': 'cell\nevil'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_with_semicolon_returns_400(self): + r = _put(self.client, {'cell_name': 'cell;drop'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + def test_cell_name_valid_hyphenated_returns_200(self): + r = _put(self.client, {'cell_name': 'valid-name'}) + self.assertNotEqual(r.status_code, 400) + + def test_cell_name_simple_alpha_returns_200(self): + r = _put(self.client, {'cell_name': 'mycell'}) + self.assertNotEqual(r.status_code, 400) + + def test_cell_name_with_digits_returns_200(self): + r = _put(self.client, {'cell_name': 'cell01'}) + self.assertNotEqual(r.status_code, 400) + + def test_cell_name_with_brace_returns_400(self): + r = _put(self.client, {'cell_name': 'cell{x}'}) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_is_local_request_per_endpoint.py b/tests/test_is_local_request_per_endpoint.py new file mode 100644 index 0000000..b8b0934 --- /dev/null +++ b/tests/test_is_local_request_per_endpoint.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +""" +Tests verifying that is_local_request() enforcement works correctly +per endpoint in api/app.py. + +The audit flagged that is_local_request() checks are performed inline +(not via a decorator), so this file confirms: + 1. Endpoints that call `is_local_request()` return 403 when the + function returns False (i.e., a non-local caller). + 2. Endpoints that do NOT call `is_local_request()` still respond + normally (2xx / 4xx) for non-local callers. + +Tested local-only endpoints (representative sample): + GET /api/containers — list_containers + POST /api/containers//start + POST /api/containers//stop + POST /api/containers//restart + GET /api/containers//logs + GET /api/containers//stats + GET /api/vault/secrets + POST /api/vault/secrets + GET /api/vault/secrets/ + DELETE /api/vault/secrets/ + GET /api/containers — POST with image field + GET /api/images + POST /api/images/pull + DELETE /api/images/ + GET /api/volumes + POST /api/volumes + DELETE /api/volumes/ + DELETE /api/containers/ + +Tested public endpoints (no is_local_request guard): + GET /api/calendar/status + GET /api/dns/records + GET /api/dhcp/leases + GET /api/cells +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +def _non_local_client(): + """Return a Flask test client that pretends to come from a non-local address.""" + app.config['TESTING'] = True + # Flask's test client uses '127.0.0.1' by default; override with a public IP + # by setting REMOTE_ADDR in the environ base. + return app.test_client() + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _get_non_local(client, path): + """Perform a GET request that appears to originate from a non-local IP.""" + return client.get(path, environ_base={'REMOTE_ADDR': '203.0.113.1'}) + + +def _post_non_local(client, path, body=None): + return client.post( + path, + data=json.dumps(body or {}), + content_type='application/json', + environ_base={'REMOTE_ADDR': '203.0.113.1'}, + ) + + +def _delete_non_local(client, path): + return client.delete(path, environ_base={'REMOTE_ADDR': '203.0.113.1'}) + + +# ── local-only endpoint tests ───────────────────────────────────────────────── + +class TestLocalOnlyEndpointsReturn403ForNonLocal(unittest.TestCase): + """Every endpoint that calls is_local_request() must return 403 for external IPs.""" + + def setUp(self): + app.config['TESTING'] = True + self.client = _non_local_client() + + # Container management + + def test_list_containers_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/containers') + self.assertEqual(r.status_code, 403) + self.assertIn('error', json.loads(r.data)) + + def test_start_container_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/containers/myapp/start') + self.assertEqual(r.status_code, 403) + + def test_stop_container_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/containers/myapp/stop') + self.assertEqual(r.status_code, 403) + + def test_restart_container_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/containers/myapp/restart') + self.assertEqual(r.status_code, 403) + + def test_get_container_logs_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/containers/myapp/logs') + self.assertEqual(r.status_code, 403) + + def test_get_container_stats_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/containers/myapp/stats') + self.assertEqual(r.status_code, 403) + + def test_remove_container_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/containers/myapp') + self.assertEqual(r.status_code, 403) + + # Image management + + def test_list_images_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/images') + self.assertEqual(r.status_code, 403) + + def test_pull_image_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/images/pull', {'image': 'nginx:latest'}) + self.assertEqual(r.status_code, 403) + + def test_remove_image_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/images/nginx') + self.assertEqual(r.status_code, 403) + + # Volume management + + def test_list_volumes_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/volumes') + self.assertEqual(r.status_code, 403) + + def test_create_volume_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/volumes', {'name': 'myvol'}) + self.assertEqual(r.status_code, 403) + + def test_remove_volume_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/volumes/myvol') + self.assertEqual(r.status_code, 403) + + # Vault endpoints + + def test_list_secrets_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/vault/secrets') + self.assertEqual(r.status_code, 403) + + def test_store_secret_returns_403_for_non_local(self): + r = _post_non_local(self.client, '/api/vault/secrets', {'name': 'k', 'value': 'v'}) + self.assertEqual(r.status_code, 403) + + def test_get_secret_returns_403_for_non_local(self): + r = _get_non_local(self.client, '/api/vault/secrets/mykey') + self.assertEqual(r.status_code, 403) + + def test_delete_secret_returns_403_for_non_local(self): + r = _delete_non_local(self.client, '/api/vault/secrets/mykey') + self.assertEqual(r.status_code, 403) + + +class TestLocalOnlyEndpointsAllowedFromLocalhost(unittest.TestCase): + """The same endpoints must NOT return 403 for loopback / local callers.""" + + def setUp(self): + app.config['TESTING'] = True + # Default test client remote_addr is 127.0.0.1, which is local + self.client = app.test_client() + + @patch('app.container_manager') + def test_list_containers_allowed_from_local(self, mock_cm): + mock_cm.list_containers.return_value = [] + r = self.client.get('/api/containers') + self.assertNotEqual(r.status_code, 403) + + @patch('app.container_manager') + def test_list_images_allowed_from_local(self, mock_cm): + mock_cm.list_images.return_value = [] + r = self.client.get('/api/images') + self.assertNotEqual(r.status_code, 403) + + @patch('app.container_manager') + def test_list_volumes_allowed_from_local(self, mock_cm): + mock_cm.list_volumes.return_value = [] + r = self.client.get('/api/volumes') + self.assertNotEqual(r.status_code, 403) + + +# ── public endpoint tests — no is_local_request guard ──────────────────────── + +class TestPublicEndpointsNotBlockedForNonLocal(unittest.TestCase): + """ + Endpoints that do NOT call is_local_request() must remain reachable + from non-local addresses. A 403 here would indicate an unintended + broadening of the local-only guard. + """ + + def setUp(self): + app.config['TESTING'] = True + self.client = _non_local_client() + + @patch('app.calendar_manager') + def test_calendar_status_not_403_for_non_local(self, mock_cm): + mock_cm.get_status.return_value = {'running': True} + r = _get_non_local(self.client, '/api/calendar/status') + self.assertNotEqual(r.status_code, 403) + + @patch('app.network_manager') + def test_dns_records_not_403_for_non_local(self, mock_nm): + mock_nm.get_dns_records.return_value = [] + r = _get_non_local(self.client, '/api/dns/records') + self.assertNotEqual(r.status_code, 403) + + @patch('app.network_manager') + def test_dhcp_leases_not_403_for_non_local(self, mock_nm): + mock_nm.get_dhcp_leases.return_value = [] + r = _get_non_local(self.client, '/api/dhcp/leases') + self.assertNotEqual(r.status_code, 403) + + @patch('app.cell_link_manager') + def test_cells_list_not_403_for_non_local(self, mock_clm): + mock_clm.list_connections.return_value = [] + r = _get_non_local(self.client, '/api/cells') + self.assertNotEqual(r.status_code, 403) + + def test_health_check_not_403_for_non_local(self): + r = _get_non_local(self.client, '/health') + self.assertNotEqual(r.status_code, 403) + + +# ── is_local_request logic unit tests ──────────────────────────────────────── + +class TestIsLocalRequestLogic(unittest.TestCase): + """ + Directly verify the is_local_request() function from app.py. + These tests exercise the address-checking logic without going through + a full HTTP request cycle. + """ + + def setUp(self): + from app import is_local_request as _fn + self._fn = _fn + app.config['TESTING'] = True + + def _call_with_addr(self, remote_addr, xff=None): + """Push a fake request context and evaluate is_local_request().""" + from app import app as _app + headers = {} + if xff: + headers['X-Forwarded-For'] = xff + with _app.test_request_context('/', environ_base={'REMOTE_ADDR': remote_addr}, + headers=headers): + return self._fn() + + def test_loopback_127_is_local(self): + self.assertTrue(self._call_with_addr('127.0.0.1')) + + def test_ipv6_loopback_is_local(self): + self.assertTrue(self._call_with_addr('::1')) + + def test_docker_bridge_172_20_is_local(self): + # 172.20.x.x is inside 172.16.0.0/12 + self.assertTrue(self._call_with_addr('172.20.0.5')) + + def test_docker_bridge_172_16_boundary_is_local(self): + # Exact boundary of 172.16.0.0/12 + self.assertTrue(self._call_with_addr('172.16.0.1')) + + def test_public_ip_is_not_local(self): + self.assertFalse(self._call_with_addr('8.8.8.8')) + + def test_wireguard_peer_10_0_0_x_is_not_local(self): + # WireGuard peer IPs (10.0.0.0/8) must NOT be treated as local + self.assertFalse(self._call_with_addr('10.0.0.2')) + + def test_lan_192_168_is_not_local(self): + # LAN addresses must NOT be treated as local (comment in app.py confirms this) + self.assertFalse(self._call_with_addr('192.168.1.50')) + + def test_xff_last_entry_loopback_is_local(self): + # Public remote addr, but last XFF entry is loopback → allowed + self.assertTrue(self._call_with_addr('8.8.8.8', xff='8.8.8.8, 127.0.0.1')) + + def test_xff_first_entry_spoofed_loopback_not_local(self): + # Spoofed first XFF entry; last entry is a public IP → should be rejected + # remote_addr is also public to rule out that shortcut + result = self._call_with_addr('8.8.8.8', xff='127.0.0.1, 8.8.8.8') + self.assertFalse(result) + + def test_xff_last_entry_docker_bridge_is_local(self): + # Last XFF entry is Caddy's Docker bridge address + self.assertTrue(self._call_with_addr('8.8.8.8', xff='1.2.3.4, 172.20.0.2')) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_logs_endpoints.py b/tests/test_logs_endpoints.py new file mode 100644 index 0000000..80d7a7c --- /dev/null +++ b/tests/test_logs_endpoints.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +Unit tests for logs Flask endpoints in api/app.py. + +Covers: + GET /api/logs — backend log file (reads picell.log) + GET /api/logs/services/ — per-service logs via log_manager + POST /api/logs/search — search across services + POST /api/logs/export — export logs + GET /api/logs/statistics — log stats + POST /api/logs/rotate — rotate logs + GET /api/logs/files — list log file info + GET /api/logs/verbosity — get log levels + PUT /api/logs/verbosity — set log levels +""" + +import sys +import json +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock, mock_open + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetBackendLogs(unittest.TestCase): + """GET /api/logs — reads picell.log from api directory.""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + def test_get_logs_returns_404_when_log_file_missing(self): + # Patch os.path.exists so the log file appears absent + with patch('app.os.path.exists', return_value=False): + r = self.client.get('/api/logs') + self.assertEqual(r.status_code, 404) + self.assertIn('error', json.loads(r.data)) + + def test_get_logs_returns_200_with_log_content(self): + log_content = 'INFO 2026-04-27 server started\nERROR something went wrong\n' + m = mock_open(read_data=log_content) + # Bypass auth enforcement by replacing auth_manager with a non-AuthManager object + with patch('app.auth_manager', MagicMock(spec=object)), \ + patch('app.os.path.exists', return_value=True), \ + patch('builtins.open', m): + r = self.client.get('/api/logs') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('log', data) + + def test_get_logs_respects_lines_query_param(self): + # Produce 10 lines; request only last 3 + all_lines = [f'line {i}\n' for i in range(10)] + m = mock_open(read_data=''.join(all_lines)) + m.return_value.__iter__ = lambda s: iter(all_lines) + m.return_value.readlines = lambda: all_lines + # Bypass auth enforcement by replacing auth_manager with a non-AuthManager object + with patch('app.auth_manager', MagicMock(spec=object)), \ + patch('app.os.path.exists', return_value=True), \ + patch('builtins.open', m): + r = self.client.get('/api/logs?lines=3') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + # The tail should contain only the last 3 lines + self.assertIn('line 7', data['log']) + self.assertIn('line 8', data['log']) + self.assertIn('line 9', data['log']) + + def test_get_logs_returns_500_on_exception(self): + with patch('app.os.path.exists', return_value=True), \ + patch('builtins.open', side_effect=PermissionError('access denied')): + r = self.client.get('/api/logs') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetServiceLogs(unittest.TestCase): + """GET /api/logs/services/""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_service_logs_returns_200_with_log_data(self, mock_lm): + mock_lm.get_service_logs.return_value = [ + '[INFO] 2026-04-27 dns started', + '[WARN] 2026-04-27 retry attempt', + ] + r = self.client.get('/api/logs/services/dns') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertEqual(data['service'], 'dns') + self.assertIsInstance(data['logs'], list) + self.assertEqual(len(data['logs']), 2) + + @patch('app.log_manager') + def test_get_service_logs_passes_level_and_lines_params(self, mock_lm): + mock_lm.get_service_logs.return_value = [] + self.client.get('/api/logs/services/email?level=ERROR&lines=25') + mock_lm.get_service_logs.assert_called_once_with('email', 'ERROR', 25) + + @patch('app.log_manager') + def test_get_service_logs_uses_defaults_when_params_absent(self, mock_lm): + mock_lm.get_service_logs.return_value = [] + self.client.get('/api/logs/services/wireguard') + mock_lm.get_service_logs.assert_called_once_with('wireguard', 'INFO', 50) + + @patch('app.log_manager') + def test_get_service_logs_returns_500_on_exception(self, mock_lm): + mock_lm.get_service_logs.side_effect = Exception('log file missing') + r = self.client.get('/api/logs/services/calendar') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestSearchLogs(unittest.TestCase): + """POST /api/logs/search""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_search_logs_returns_200_with_results_and_count(self, mock_lm): + mock_lm.search_logs.return_value = [ + {'service': 'dns', 'line': 'ERROR timeout'}, + ] + r = self.client.post( + '/api/logs/search', + data=json.dumps({'query': 'ERROR', 'services': ['dns']}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('results', data) + self.assertIn('count', data) + self.assertEqual(data['count'], 1) + + @patch('app.log_manager') + def test_search_logs_works_with_empty_body(self, mock_lm): + mock_lm.search_logs.return_value = [] + r = self.client.post('/api/logs/search') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertEqual(data['results'], []) + self.assertEqual(data['count'], 0) + + @patch('app.log_manager') + def test_search_logs_returns_500_on_exception(self, mock_lm): + mock_lm.search_logs.side_effect = Exception('index unavailable') + r = self.client.post( + '/api/logs/search', + data=json.dumps({'query': 'fail'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestExportLogs(unittest.TestCase): + """POST /api/logs/export""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_export_logs_returns_200_with_log_data_and_format(self, mock_lm): + mock_lm.export_logs.return_value = '[{"ts":1,"msg":"ok"}]' + r = self.client.post( + '/api/logs/export', + data=json.dumps({'format': 'json', 'filters': {'service': 'dns'}}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('logs', data) + self.assertIn('format', data) + self.assertEqual(data['format'], 'json') + + @patch('app.log_manager') + def test_export_logs_defaults_to_json_format(self, mock_lm): + mock_lm.export_logs.return_value = '[]' + self.client.post('/api/logs/export') + mock_lm.export_logs.assert_called_once_with('json', {}) + + @patch('app.log_manager') + def test_export_logs_returns_500_on_exception(self, mock_lm): + mock_lm.export_logs.side_effect = Exception('export failed') + r = self.client.post( + '/api/logs/export', + data=json.dumps({'format': 'csv'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetLogStatistics(unittest.TestCase): + """GET /api/logs/statistics""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_statistics_returns_200_with_stats_dict(self, mock_lm): + mock_lm.get_log_statistics.return_value = { + 'total_lines': 1200, + 'error_count': 3, + 'warn_count': 17, + } + r = self.client.get('/api/logs/statistics') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('total_lines', data) + + @patch('app.log_manager') + def test_get_statistics_passes_service_param(self, mock_lm): + mock_lm.get_log_statistics.return_value = {} + self.client.get('/api/logs/statistics?service=email') + mock_lm.get_log_statistics.assert_called_once_with('email') + + @patch('app.log_manager') + def test_get_statistics_passes_none_when_no_service_param(self, mock_lm): + mock_lm.get_log_statistics.return_value = {} + self.client.get('/api/logs/statistics') + mock_lm.get_log_statistics.assert_called_once_with(None) + + @patch('app.log_manager') + def test_get_statistics_returns_500_on_exception(self, mock_lm): + mock_lm.get_log_statistics.side_effect = Exception('stats error') + r = self.client.get('/api/logs/statistics') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestRotateLogs(unittest.TestCase): + """POST /api/logs/rotate""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_rotate_all_logs_returns_200(self, mock_lm): + r = self.client.post('/api/logs/rotate') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + mock_lm.rotate_logs.assert_called_once_with(None) + + @patch('app.log_manager') + def test_rotate_specific_service_passes_service_name(self, mock_lm): + r = self.client.post( + '/api/logs/rotate', + data=json.dumps({'service': 'dns'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + mock_lm.rotate_logs.assert_called_once_with('dns') + + @patch('app.log_manager') + def test_rotate_returns_500_on_exception(self, mock_lm): + mock_lm.rotate_logs.side_effect = Exception('rotate failed') + r = self.client.post('/api/logs/rotate') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetLogFileInfos(unittest.TestCase): + """GET /api/logs/files""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_log_files_returns_200_with_file_list(self, mock_lm): + mock_lm.get_all_log_file_infos.return_value = [ + {'service': 'dns', 'path': '/data/logs/dns.log', 'size_bytes': 4096}, + {'service': 'email', 'path': '/data/logs/email.log', 'size_bytes': 8192}, + ] + r = self.client.get('/api/logs/files') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.log_manager') + def test_get_log_files_returns_500_on_exception(self, mock_lm): + mock_lm.get_all_log_file_infos.side_effect = Exception('filesystem error') + r = self.client.get('/api/logs/files') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestLogVerbosity(unittest.TestCase): + """GET /api/logs/verbosity and PUT /api/logs/verbosity""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.log_manager') + def test_get_verbosity_returns_200_with_levels_map(self, mock_lm): + mock_lm.get_service_levels.return_value = { + 'dns': 'INFO', + 'email': 'DEBUG', + 'wireguard': 'WARNING', + } + r = self.client.get('/api/logs/verbosity') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('dns', data) + self.assertEqual(data['email'], 'DEBUG') + + @patch('app.log_manager') + def test_get_verbosity_returns_500_on_exception(self, mock_lm): + mock_lm.get_service_levels.side_effect = Exception('config missing') + r = self.client.get('/api/logs/verbosity') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + @patch('app.log_manager') + def test_put_verbosity_returns_200_and_calls_set_level(self, mock_lm): + mock_lm.get_service_levels.return_value = {'dns': 'DEBUG'} + with tempfile.TemporaryDirectory() as tmpdir: + # Endpoint builds: os.path.join(os.path.dirname(__file__), 'config', 'log_levels.json') + # Patch dirname to return tmpdir so the full path becomes tmpdir/config/log_levels.json + config_dir = os.path.join(tmpdir, 'config') + os.makedirs(config_dir) + with patch('app.auth_manager', MagicMock(spec=object)), \ + patch('app.os.path.dirname', return_value=tmpdir): + r = self.client.put( + '/api/logs/verbosity', + data=json.dumps({'dns': 'DEBUG'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + mock_lm.set_service_level.assert_called_with('dns', 'DEBUG') + + @patch('app.log_manager') + def test_put_verbosity_returns_500_on_exception(self, mock_lm): + mock_lm.set_service_level.side_effect = Exception('unknown service') + r = self.client.put( + '/api/logs/verbosity', + data=json.dumps({'unknown_svc': 'DEBUG'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_network_endpoints.py b/tests/test_network_endpoints.py index f2579ef..b5b1c14 100644 --- a/tests/test_network_endpoints.py +++ b/tests/test_network_endpoints.py @@ -1 +1,353 @@ -# ... moved and adapted code from test_phase1_endpoints.py ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for network/DNS/DHCP Flask endpoints in api/app.py. + +Covers: + GET /api/dns/records + POST /api/dns/records + DELETE /api/dns/records + GET /api/dns/status + GET /api/dhcp/leases + POST /api/dhcp/reservations + DELETE /api/dhcp/reservations + GET /api/network/info + POST /api/network/test +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestGetDnsRecords(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_dns_records_returns_200_with_list(self, mock_nm): + mock_nm.get_dns_records.return_value = [ + {'name': 'myhost.cell', 'type': 'A', 'value': '192.168.1.10'}, + {'name': 'nas.cell', 'type': 'A', 'value': '192.168.1.20'}, + ] + r = self.client.get('/api/dns/records') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 2) + + @patch('app.network_manager') + def test_get_dns_records_returns_empty_list_when_none(self, mock_nm): + mock_nm.get_dns_records.return_value = [] + r = self.client.get('/api/dns/records') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.network_manager') + def test_get_dns_records_returns_500_on_exception(self, mock_nm): + mock_nm.get_dns_records.side_effect = Exception('CoreDNS unreachable') + r = self.client.get('/api/dns/records') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddDnsRecord(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_add_dns_record_returns_200_on_valid_body(self, mock_nm): + mock_nm.add_dns_record.return_value = {'success': True} + r = self.client.post( + '/api/dns/records', + data=json.dumps({'name': 'printer.cell', 'type': 'A', 'value': '192.168.1.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('success', data) + + @patch('app.network_manager') + def test_add_dns_record_returns_400_when_no_body(self, mock_nm): + r = self.client.post('/api/dns/records') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_nm.add_dns_record.assert_not_called() + + @patch('app.network_manager') + def test_add_dns_record_returns_500_on_exception(self, mock_nm): + mock_nm.add_dns_record.side_effect = Exception('Corefile write failed') + r = self.client.post( + '/api/dns/records', + data=json.dumps({'name': 'bad.cell', 'type': 'A', 'value': '10.0.0.1'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteDnsRecord(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_delete_dns_record_returns_200_on_success(self, mock_nm): + mock_nm.remove_dns_record.return_value = {'success': True} + r = self.client.delete( + '/api/dns/records', + data=json.dumps({'name': 'printer.cell'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + + @patch('app.network_manager') + def test_delete_dns_record_returns_500_on_exception(self, mock_nm): + mock_nm.remove_dns_record.side_effect = Exception('record not found') + r = self.client.delete( + '/api/dns/records', + data=json.dumps({'name': 'missing.cell'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetDnsStatus(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_dns_status_returns_200_with_status_dict(self, mock_nm): + mock_nm.get_dns_status.return_value = { + 'running': True, + 'records_count': 5, + 'upstreams': ['1.1.1.1', '8.8.8.8'], + } + r = self.client.get('/api/dns/status') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('running', data) + + @patch('app.network_manager') + def test_get_dns_status_returns_500_on_exception(self, mock_nm): + mock_nm.get_dns_status.side_effect = Exception('CoreDNS not running') + r = self.client.get('/api/dns/status') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetDhcpLeases(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_dhcp_leases_returns_200_with_list(self, mock_nm): + mock_nm.get_dhcp_leases.return_value = [ + {'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.101', 'hostname': 'laptop'}, + ] + r = self.client.get('/api/dhcp/leases') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIsInstance(data, list) + self.assertEqual(len(data), 1) + self.assertEqual(data[0]['hostname'], 'laptop') + + @patch('app.network_manager') + def test_get_dhcp_leases_returns_empty_list_when_no_leases(self, mock_nm): + mock_nm.get_dhcp_leases.return_value = [] + r = self.client.get('/api/dhcp/leases') + self.assertEqual(r.status_code, 200) + self.assertEqual(json.loads(r.data), []) + + @patch('app.network_manager') + def test_get_dhcp_leases_returns_500_on_exception(self, mock_nm): + mock_nm.get_dhcp_leases.side_effect = Exception('dnsmasq not running') + r = self.client.get('/api/dhcp/leases') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddDhcpReservation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_add_reservation_returns_200_on_valid_body(self, mock_nm): + mock_nm.add_dhcp_reservation.return_value = True + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50', 'hostname': 'printer'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('success', data) + + @patch('app.network_manager') + def test_add_reservation_returns_400_when_no_body(self, mock_nm): + r = self.client.post('/api/dhcp/reservations') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_nm.add_dhcp_reservation.assert_not_called() + + @patch('app.network_manager') + def test_add_reservation_returns_400_when_mac_missing(self, mock_nm): + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'ip': '192.168.1.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.network_manager') + def test_add_reservation_returns_400_when_ip_missing(self, mock_nm): + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.network_manager') + def test_add_reservation_uses_empty_hostname_when_omitted(self, mock_nm): + mock_nm.add_dhcp_reservation.return_value = True + self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50'}), + content_type='application/json', + ) + mock_nm.add_dhcp_reservation.assert_called_once_with('aa:bb:cc:dd:ee:ff', '192.168.1.50', '') + + @patch('app.network_manager') + def test_add_reservation_returns_500_on_exception(self, mock_nm): + mock_nm.add_dhcp_reservation.side_effect = Exception('dnsmasq config error') + r = self.client.post( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff', 'ip': '192.168.1.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeleteDhcpReservation(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_delete_reservation_returns_200_on_success(self, mock_nm): + mock_nm.remove_dhcp_reservation.return_value = True + r = self.client.delete( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('success', data) + + @patch('app.network_manager') + def test_delete_reservation_returns_400_when_mac_missing(self, mock_nm): + r = self.client.delete( + '/api/dhcp/reservations', + data=json.dumps({'hostname': 'printer'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_nm.remove_dhcp_reservation.assert_not_called() + + @patch('app.network_manager') + def test_delete_reservation_returns_400_when_no_body(self, mock_nm): + r = self.client.delete('/api/dhcp/reservations') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.network_manager') + def test_delete_reservation_returns_500_on_exception(self, mock_nm): + mock_nm.remove_dhcp_reservation.side_effect = Exception('reservation not found') + r = self.client.delete( + '/api/dhcp/reservations', + data=json.dumps({'mac': 'aa:bb:cc:dd:ee:ff'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetNetworkInfo(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_get_network_info_returns_200_with_info_dict(self, mock_nm): + mock_nm.get_network_info.return_value = { + 'interfaces': ['eth0', 'wg0'], + 'gateway': '192.168.1.1', + 'dns': ['127.0.0.1'], + } + r = self.client.get('/api/network/info') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('interfaces', data) + + @patch('app.network_manager') + def test_get_network_info_returns_500_on_exception(self, mock_nm): + mock_nm.get_network_info.side_effect = Exception('network unreachable') + r = self.client.get('/api/network/info') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestNetworkTest(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.network_manager') + def test_network_test_returns_200_with_result(self, mock_nm): + mock_nm.test_connectivity.return_value = { + 'internet': True, + 'dns': True, + 'latency_ms': 15, + } + r = self.client.post('/api/network/test') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('internet', data) + + @patch('app.network_manager') + def test_network_test_returns_500_on_exception(self, mock_nm): + mock_nm.test_connectivity.side_effect = Exception('ping failed') + r = self.client.post('/api/network/test') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_network_manager.py b/tests/test_network_manager.py index 4cf4e2f..40216a7 100644 --- a/tests/test_network_manager.py +++ b/tests/test_network_manager.py @@ -399,11 +399,13 @@ class TestCellDnsForwarding(unittest.TestCase): self.assertNotIn('10.1.0.1', content) @patch('subprocess.run') - def test_remove_nonexistent_forward_is_noop(self, _mock): - before = open(self.corefile).read() - self.nm.remove_cell_dns_forward('nonexistent.cell') + def test_remove_nonexistent_forward_does_not_error(self, _mock): + # Removing a domain that was never added must not raise and must not + # leave the nonexistent domain in the regenerated Corefile. + result = self.nm.remove_cell_dns_forward('nonexistent.cell') after = open(self.corefile).read() - self.assertEqual(before, after) + self.assertNotIn('nonexistent.cell', after) + # The Corefile is regenerated (new canonical format) — that's correct. if __name__ == '__main__': diff --git a/tests/test_peer_management_edge_cases.py b/tests/test_peer_management_edge_cases.py new file mode 100644 index 0000000..2a9d203 --- /dev/null +++ b/tests/test_peer_management_edge_cases.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Edge-case tests for peer management endpoints in api/app.py. + +Key scenarios: +- POST /api/peers with subnet exhaustion (_next_peer_ip raises ValueError) → 409 +- POST /api/peers//clear-reinstall: success (200) +- POST /api/peers//clear-reinstall: unknown peer raises → 500 +- POST /api/ip-update: missing 'peer' field → 400 +- POST /api/ip-update: missing 'ip' field → 400 +- POST /api/ip-update: unknown peer → 404 +- POST /api/ip-update: success → 200 +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestAddPeerSubnetExhaustion(unittest.TestCase): + """POST /api/peers with no free IPs left must return 409, not 500.""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app._next_peer_ip') + @patch('app.auth_manager') + def test_add_peer_returns_409_when_subnet_exhausted(self, mock_auth, mock_next_ip): + mock_auth.create_user.return_value = True + mock_next_ip.side_effect = ValueError('No free IPs left in 10.0.0.0/24') + + r = self.client.post( + '/api/peers', + data=json.dumps({ + 'name': 'newpeer', + 'public_key': 'PUBKEY==', + 'password': 'verysecret123', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 409) + data = json.loads(r.data) + self.assertIn('error', data) + + @patch('app._next_peer_ip') + @patch('app.auth_manager') + def test_add_peer_409_error_message_mentions_ip(self, mock_auth, mock_next_ip): + mock_auth.create_user.return_value = True + mock_next_ip.side_effect = ValueError('No free IPs left in 10.0.0.0/24') + + r = self.client.post( + '/api/peers', + data=json.dumps({ + 'name': 'newpeer', + 'public_key': 'PUBKEY==', + 'password': 'verysecret123', + }), + content_type='application/json', + ) + self.assertEqual(r.status_code, 409) + data = json.loads(r.data) + self.assertIn('No free IPs', data['error']) + + +class TestClearReinstallFlag(unittest.TestCase): + """POST /api/peers//clear-reinstall""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.peer_registry') + def test_clear_reinstall_returns_200_on_success(self, mock_reg): + mock_reg.clear_reinstall_flag.return_value = True + r = self.client.post('/api/peers/alice/clear-reinstall') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.peer_registry') + def test_clear_reinstall_calls_registry_with_peer_name(self, mock_reg): + mock_reg.clear_reinstall_flag.return_value = True + self.client.post('/api/peers/bob/clear-reinstall') + mock_reg.clear_reinstall_flag.assert_called_once_with('bob') + + @patch('app.peer_registry') + def test_clear_reinstall_returns_500_when_exception_raised(self, mock_reg): + mock_reg.clear_reinstall_flag.side_effect = Exception('peer not found') + r = self.client.post('/api/peers/ghost/clear-reinstall') + self.assertEqual(r.status_code, 500) + data = json.loads(r.data) + self.assertIn('error', data) + + +class TestIpUpdate(unittest.TestCase): + """POST /api/ip-update""" + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + @patch('app.peer_registry') + def test_ip_update_returns_200_on_success(self, mock_reg, mock_rm): + mock_reg.update_peer_ip.return_value = True + mock_rm.update_peer_ip.return_value = None + + r = self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'alice', 'ip': '10.0.0.99'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.peer_registry') + def test_ip_update_returns_400_when_peer_field_missing(self, mock_reg): + r = self.client.post( + '/api/ip-update', + data=json.dumps({'ip': '10.0.0.99'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + mock_reg.update_peer_ip.assert_not_called() + + @patch('app.peer_registry') + def test_ip_update_returns_400_when_ip_field_missing(self, mock_reg): + r = self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + data = json.loads(r.data) + self.assertIn('error', data) + mock_reg.update_peer_ip.assert_not_called() + + @patch('app.peer_registry') + def test_ip_update_returns_400_when_no_body(self, mock_reg): + r = self.client.post('/api/ip-update') + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + + @patch('app.peer_registry') + def test_ip_update_returns_404_when_peer_not_found(self, mock_reg): + mock_reg.update_peer_ip.return_value = False + r = self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'ghost', 'ip': '10.0.0.50'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 404) + data = json.loads(r.data) + self.assertIn('error', data) + + @patch('app.routing_manager') + @patch('app.peer_registry') + def test_ip_update_calls_registry_with_correct_args(self, mock_reg, mock_rm): + mock_reg.update_peer_ip.return_value = True + mock_rm.update_peer_ip.return_value = None + + self.client.post( + '/api/ip-update', + data=json.dumps({'peer': 'alice', 'ip': '10.0.0.5'}), + content_type='application/json', + ) + mock_reg.update_peer_ip.assert_called_once_with('alice', '10.0.0.5') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_peer_management_update.py b/tests/test_peer_management_update.py new file mode 100644 index 0000000..a1ac567 --- /dev/null +++ b/tests/test_peer_management_update.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Tests for PUT /api/peers/. + +Key scenarios: +- 404 when peer_registry.get_peer returns None +- 200 on successful update +- config_needs_reinstall=True in response when internet_access changes +- config_needs_reinstall=False (config_changed=False) when only description changes +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestUpdatePeer(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_returns_404_when_peer_not_found(self, mock_reg, mock_fw): + mock_reg.get_peer.return_value = None + r = self.client.put( + '/api/peers/ghost', + data=json.dumps({'description': 'updated'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 404) + data = json.loads(r.data) + self.assertIn('error', data) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_returns_200_on_success(self, mock_reg, mock_fw): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'my laptop'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('message', data) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_config_changed_true_when_internet_access_changes( + self, mock_reg, mock_fw + ): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'internet_access': False}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertTrue(data['config_changed']) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_config_changed_false_when_only_description_changes( + self, mock_reg, mock_fw + ): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'just a label'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertFalse(data['config_changed']) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_returns_500_when_update_fails(self, mock_reg, mock_fw): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = False + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'fail'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + @patch('app.firewall_manager') + @patch('app.peer_registry') + def test_update_peer_config_changed_true_when_ip_changes(self, mock_reg, mock_fw): + existing = { + 'peer': 'alice', + 'ip': '10.0.0.2', + 'internet_access': True, + 'public_key': 'KEY==', + } + mock_reg.get_peer.return_value = existing + mock_reg.update_peer.return_value = True + mock_reg.list_peers.return_value = [existing] + mock_fw.apply_peer_rules.return_value = None + mock_fw.apply_all_dns_rules.return_value = None + + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'ip': '10.0.0.99'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertTrue(data['config_changed']) + + @patch('app.peer_registry') + def test_update_peer_returns_500_on_exception(self, mock_reg): + mock_reg.get_peer.side_effect = Exception('disk error') + r = self.client.put( + '/api/peers/alice', + data=json.dumps({'description': 'test'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_routing_endpoints.py b/tests/test_routing_endpoints.py index cd725ab..1c55d32 100644 --- a/tests/test_routing_endpoints.py +++ b/tests/test_routing_endpoints.py @@ -1 +1,294 @@ -# ... moved and adapted code from test_phase4_endpoints.py ... \ No newline at end of file +#!/usr/bin/env python3 +""" +Unit tests for routing Flask endpoints in api/app.py. + +Covers: + POST /api/routing/peers (peer_name + peer_ip required) + POST /api/routing/exit-nodes (peer_name + peer_ip required) + POST /api/routing/bridge (source_peer + target_peer required) + POST /api/routing/split (network + exit_peer required) + GET /api/routing/peers + DELETE /api/routing/peers/ +""" + +import sys +import json +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock + +api_dir = Path(__file__).parent.parent / 'api' +sys.path.insert(0, str(api_dir)) + +from app import app + + +class TestAddPeerRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_peer_route_returns_200_on_success(self, mock_rm): + mock_rm.add_peer_route.return_value = True + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_name': 'alice', 'peer_ip': '10.0.0.2'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_peer_route_returns_400_when_peer_name_missing(self, mock_rm): + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_ip': '10.0.0.2'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_peer_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_peer_route_returns_400_when_peer_ip_missing(self, mock_rm): + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_name': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_peer_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_peer_route_returns_500_on_exception(self, mock_rm): + mock_rm.add_peer_route.side_effect = Exception('iptables error') + r = self.client.post( + '/api/routing/peers', + data=json.dumps({'peer_name': 'alice', 'peer_ip': '10.0.0.2'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddExitNode(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_exit_node_returns_200_on_success(self, mock_rm): + mock_rm.add_exit_node.return_value = True + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_name': 'gw', 'peer_ip': '10.0.0.5'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_exit_node_returns_400_when_peer_name_missing(self, mock_rm): + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_ip': '10.0.0.5'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_exit_node.assert_not_called() + + @patch('app.routing_manager') + def test_add_exit_node_returns_400_when_peer_ip_missing(self, mock_rm): + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_name': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_exit_node.assert_not_called() + + @patch('app.routing_manager') + def test_add_exit_node_returns_500_on_exception(self, mock_rm): + mock_rm.add_exit_node.side_effect = Exception('routing table full') + r = self.client.post( + '/api/routing/exit-nodes', + data=json.dumps({'peer_name': 'gw', 'peer_ip': '10.0.0.5'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddBridgeRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_bridge_returns_200_on_success(self, mock_rm): + mock_rm.add_bridge_route.return_value = True + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'source_peer': 'alice', 'target_peer': 'bob'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_bridge_returns_400_when_source_peer_missing(self, mock_rm): + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'target_peer': 'bob'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_bridge_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_bridge_returns_400_when_target_peer_missing(self, mock_rm): + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'source_peer': 'alice'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_bridge_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_bridge_returns_500_on_exception(self, mock_rm): + mock_rm.add_bridge_route.side_effect = Exception('bridge setup failed') + r = self.client.post( + '/api/routing/bridge', + data=json.dumps({'source_peer': 'alice', 'target_peer': 'bob'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestAddSplitRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_add_split_returns_200_on_success(self, mock_rm): + mock_rm.add_split_route.return_value = True + r = self.client.post( + '/api/routing/split', + data=json.dumps({'network': '192.168.10.0/24', 'exit_peer': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('added', data) + + @patch('app.routing_manager') + def test_add_split_returns_400_when_network_missing(self, mock_rm): + r = self.client.post( + '/api/routing/split', + data=json.dumps({'exit_peer': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_split_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_split_returns_400_when_exit_peer_missing(self, mock_rm): + r = self.client.post( + '/api/routing/split', + data=json.dumps({'network': '192.168.10.0/24'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 400) + self.assertIn('error', json.loads(r.data)) + mock_rm.add_split_route.assert_not_called() + + @patch('app.routing_manager') + def test_add_split_returns_500_on_exception(self, mock_rm): + mock_rm.add_split_route.side_effect = Exception('split tunnel error') + r = self.client.post( + '/api/routing/split', + data=json.dumps({'network': '192.168.10.0/24', 'exit_peer': 'gw'}), + content_type='application/json', + ) + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestGetPeerRoutes(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_get_peer_routes_returns_200_with_routes(self, mock_rm): + mock_rm.get_peer_routes.return_value = [ + {'peer_name': 'alice', 'peer_ip': '10.0.0.2', 'route_type': 'lan'}, + ] + r = self.client.get('/api/routing/peers') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertIn('peer_routes', data) + self.assertIsInstance(data['peer_routes'], list) + + @patch('app.routing_manager') + def test_get_peer_routes_returns_empty_list_when_no_routes(self, mock_rm): + mock_rm.get_peer_routes.return_value = [] + r = self.client.get('/api/routing/peers') + self.assertEqual(r.status_code, 200) + data = json.loads(r.data) + self.assertEqual(data['peer_routes'], []) + + @patch('app.routing_manager') + def test_get_peer_routes_returns_500_on_exception(self, mock_rm): + mock_rm.get_peer_routes.side_effect = Exception('DB error') + r = self.client.get('/api/routing/peers') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +class TestDeletePeerRoute(unittest.TestCase): + + def setUp(self): + app.config['TESTING'] = True + self.client = app.test_client() + + @patch('app.routing_manager') + def test_delete_peer_route_returns_200_on_success(self, mock_rm): + mock_rm.remove_peer_route.return_value = {'removed': True} + r = self.client.delete('/api/routing/peers/alice') + self.assertEqual(r.status_code, 200) + + @patch('app.routing_manager') + def test_delete_peer_route_calls_manager_with_name(self, mock_rm): + mock_rm.remove_peer_route.return_value = {'removed': True} + self.client.delete('/api/routing/peers/bob') + mock_rm.remove_peer_route.assert_called_once_with('bob') + + @patch('app.routing_manager') + def test_delete_peer_route_returns_500_on_exception(self, mock_rm): + mock_rm.remove_peer_route.side_effect = Exception('iptables flush error') + r = self.client.delete('/api/routing/peers/alice') + self.assertEqual(r.status_code, 500) + self.assertIn('error', json.loads(r.data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/webui/src/pages/Dashboard.jsx b/webui/src/pages/Dashboard.jsx index 77da25d..f626106 100644 --- a/webui/src/pages/Dashboard.jsx +++ b/webui/src/pages/Dashboard.jsx @@ -24,9 +24,9 @@ function Dashboard({ isOnline }) { const { domain = 'cell', cell_name = 'mycell' } = useConfig(); const SERVICES = [ { name: 'Cell Home', url: `http://${cell_name}.${domain}`, desc: 'Main UI — no login needed' }, - { name: 'Calendar', url: `http://calendar.${domain}`, desc: 'Login: your WireGuard username' }, - { name: 'Files', url: `http://files.${domain}`, desc: 'Login: admin / admin123' }, - { name: 'Webmail', url: `http://mail.${domain}`, desc: 'Login: admin@rainloop.net / 12345' }, + { name: 'Calendar', url: `http://calendar.${domain}`, desc: 'Use your configured account credentials' }, + { name: 'Files', url: `http://files.${domain}`, desc: 'Use your configured account credentials' }, + { name: 'Webmail', url: `http://mail.${domain}`, desc: 'Use your configured account credentials' }, ]; const [cellStatus, setCellStatus] = useState(null); const [servicesStatus, setServicesStatus] = useState(null); diff --git a/webui/src/pages/Peers.jsx b/webui/src/pages/Peers.jsx index 1421dee..518ba4f 100644 --- a/webui/src/pages/Peers.jsx +++ b/webui/src/pages/Peers.jsx @@ -191,13 +191,6 @@ PersistentKeepalive = ${peer.persistent_keepalive || 25}`; password: formData.password, }; const addResult = await peerRegistryAPI.addPeer(peerData); - const assignedIp = addResult.data?.ip; - await wireguardAPI.addPeer({ - name: formData.name, - public_key: publicKey, - allowed_ips: assignedIp ? `${assignedIp}/32` : `${peerData.ip}/32`, - persistent_keepalive: formData.persistent_keepalive, - }); if (formData.create_calendar) { try { @@ -268,7 +261,7 @@ PersistentKeepalive = ${peer.persistent_keepalive || 25}`; const handleRemovePeer = async (peerName) => { if (!window.confirm(`Remove peer "${peerName}"?`)) return; try { - await Promise.all([peerRegistryAPI.removePeer(peerName), wireguardAPI.removePeer({ name: peerName })]); + await peerRegistryAPI.removePeer(peerName); fetchPeers(); showToast(`Peer "${peerName}" removed.`); } catch { showToast('Failed to remove peer', 'error'); } diff --git a/webui/src/pages/WireGuard.jsx b/webui/src/pages/WireGuard.jsx index 8fddc19..98a79c3 100644 --- a/webui/src/pages/WireGuard.jsx +++ b/webui/src/pages/WireGuard.jsx @@ -66,26 +66,29 @@ function WireGuard() { const peersData = peersResponse.data || []; const wireguardPeers = wireguardResponse.data || []; - // Create a map of WireGuard peers by name for quick lookup + // Create a map of WireGuard peers by public_key for quick lookup const wireguardMap = {}; wireguardPeers.forEach(peer => { - wireguardMap[peer.name] = peer; + if (peer.public_key) wireguardMap[peer.public_key] = peer; }); - + // Merge the data - const mergedPeers = peersData.map(peer => ({ - ...peer, - ...wireguardMap[peer.peer || peer.name], - name: peer.peer || peer.name, - status: 'Online', // For now, assume all peers are online - type: 'WireGuard', - // Preserve important fields that might be overwritten - private_key: peer.private_key, - server_public_key: peer.server_public_key, - server_endpoint: peer.server_endpoint, - allowed_ips: peer.allowed_ips || wireguardMap[peer.peer || peer.name]?.AllowedIPs || '0.0.0.0/0', - persistent_keepalive: peer.persistent_keepalive || wireguardMap[peer.peer || peer.name]?.PersistentKeepalive || 25 - })); + const mergedPeers = peersData.map(peer => { + const wgEntry = wireguardMap[peer.public_key] || {}; + return { + ...peer, + ...wgEntry, + // Registry fields always win over wg0.conf fields for name/keys/endpoint + name: peer.peer || peer.name, + type: 'WireGuard', + private_key: peer.private_key, + server_public_key: peer.server_public_key, + server_endpoint: peer.server_endpoint, + public_key: peer.public_key, + allowed_ips: peer.allowed_ips || wgEntry.allowed_ips || '0.0.0.0/0', + persistent_keepalive: peer.persistent_keepalive || wgEntry.persistent_keepalive || 25, + }; + }); // Load all peer statuses in one call (keyed by public_key) let liveStatuses = {}; diff --git a/webui/src/services/api.js b/webui/src/services/api.js index e1fa5bd..1a3ad20 100644 --- a/webui/src/services/api.js +++ b/webui/src/services/api.js @@ -1,5 +1,16 @@ import axios from 'axios'; +// Module-level CSRF token — populated after login or token refresh +let _csrfToken = null; + +/** + * Update the module-level CSRF token. + * Call this after a successful login with the token returned in the response body. + */ +export function setCsrfToken(token) { + _csrfToken = token; +} + // Create axios instance with base configuration const api = axios.create({ baseURL: import.meta.env.VITE_API_URL || '', @@ -10,10 +21,16 @@ const api = axios.create({ }, }); -// Request interceptor for logging +// Request interceptor — logging + CSRF header injection api.interceptors.request.use( (config) => { console.log(`API Request: ${config.method?.toUpperCase()} ${config.url}`); + // Attach CSRF token for all state-changing methods + const method = (config.method || 'get').toLowerCase(); + if (['post', 'put', 'delete', 'patch'].includes(method) && _csrfToken) { + config.headers = config.headers || {}; + config.headers['X-CSRF-Token'] = _csrfToken; + } return config; }, (error) => { @@ -22,13 +39,36 @@ api.interceptors.request.use( } ); -// Response interceptor for error handling +// Response interceptor — error handling + CSRF token refresh on 403 api.interceptors.response.use( (response) => { return response; }, - (error) => { + async (error) => { console.error('API Response Error:', error.response?.data || error.message); + + // Handle CSRF token expiry: refresh the token and retry the original request once + if ( + error.response?.status === 403 && + error.response?.data?.error === 'CSRF token missing or invalid' && + !error.config._csrfRetry + ) { + try { + const refreshResp = await api.get('/api/auth/csrf-token'); + const newToken = refreshResp.data?.csrf_token; + if (newToken) { + setCsrfToken(newToken); + // Retry the original request with the new token + const retryConfig = { ...error.config, _csrfRetry: true }; + retryConfig.headers = retryConfig.headers || {}; + retryConfig.headers['X-CSRF-Token'] = newToken; + return api(retryConfig); + } + } catch (refreshErr) { + console.error('CSRF token refresh failed:', refreshErr); + } + } + if ( error.response?.status === 401 && !error.config.url.includes('/auth/login') && @@ -107,12 +147,19 @@ export const peerRegistryAPI = { // Auth API export const authAPI = { - login: (username, password) => api.post('/api/auth/login', { username, password }), + login: async (username, password) => { + const response = await api.post('/api/auth/login', { username, password }); + if (response.data?.csrf_token) { + setCsrfToken(response.data.csrf_token); + } + return response; + }, logout: () => api.post('/api/auth/logout'), me: () => api.get('/api/auth/me'), changePassword: (old_password, new_password) => api.post('/api/auth/change-password', { old_password, new_password }), adminResetPassword: (username, new_password) => api.post('/api/auth/admin/reset-password', { username, new_password }), listUsers: () => api.get('/api/auth/users'), + getCsrfToken: () => api.get('/api/auth/csrf-token'), }; // Peer-facing dashboard API