diff --git a/api/app.py b/api/app.py index 5d0303c..39d7d1f 100644 --- a/api/app.py +++ b/api/app.py @@ -907,13 +907,18 @@ def connectivity_apply_routes(): @app.route('/api/connectivity/peers//exit', methods=['PUT']) def connectivity_set_peer_exit(peer_name: str): - """Assign a peer to an egress exit type.""" + """Assign a peer to a connection by id (or 'default' to clear). + + Body: {"connection_id": "|default"}. The legacy {"exit_via": ""} + field is still accepted as a one-release back-compat shim and resolved to + the single connection instance of that type. + """ try: data = request.get_json(silent=True) or {} - exit_via = data.get('exit_via') - if not isinstance(exit_via, str): - return jsonify({'ok': False, 'error': 'exit_via is required'}), 400 - result = connectivity_manager.set_peer_exit(peer_name, exit_via) + connection_id = data.get('connection_id', data.get('exit_via')) + if not isinstance(connection_id, str) or not connection_id: + return jsonify({'ok': False, 'error': 'connection_id is required'}), 400 + result = connectivity_manager.set_peer_exit(peer_name, connection_id) if result.get('ok'): return jsonify(result) return jsonify(result), 400 @@ -995,13 +1000,18 @@ def egress_status(): @app.route('/api/egress/services//exit', methods=['PUT']) def egress_set_service_exit(service_id: str): - """Persist and immediately apply a per-service egress override.""" + """Persist and immediately apply a per-service egress override. + + Body: {"connection_id": "|default"}. The legacy {"exit_type": ""} + field is still accepted as a one-release back-compat shim and resolved to + the single connection instance of that type. + """ try: data = request.get_json(silent=True) or {} - exit_type = data.get('exit_type') - if not isinstance(exit_type, str): - return jsonify({'ok': False, 'error': 'exit_type is required'}), 400 - result = egress_manager.set_service_exit(service_id, exit_type) + connection_id = data.get('connection_id', data.get('exit_type')) + if not isinstance(connection_id, str) or not connection_id: + return jsonify({'ok': False, 'error': 'connection_id is required'}), 400 + result = egress_manager.set_service_exit(service_id, connection_id) if result.get('ok'): return jsonify(result) return jsonify(result), 400 diff --git a/api/connectivity_manager.py b/api/connectivity_manager.py index 3632d4d..1fb278d 100644 --- a/api/connectivity_manager.py +++ b/api/connectivity_manager.py @@ -357,14 +357,16 @@ class ConnectivityManager(BaseServiceManager): logger.warning(f"get_peer_exits: {e}") return out - def set_peer_exit(self, peer_name: str, exit_type: str) -> Dict[str, Any]: - """Assign a peer to an egress path and apply the rule changes.""" - if exit_type not in self.EXIT_TYPES: - return { - 'ok': False, - 'error': f"invalid exit_type {exit_type!r}; " - f"must be one of {self.EXIT_TYPES}", - } + def set_peer_exit(self, peer_name: str, exit_via: str) -> Dict[str, Any]: + """Assign a peer to a connection (by id) or 'default' and apply rules. + + `exit_via` is a connection id, 'default', or — as a one-release + back-compat shim — a legacy exit *type* string, which is resolved to + the single connection instance of that type. Validation that the id + exists lives in peer_registry.set_peer_exit_via. + """ + if not isinstance(exit_via, str) or not exit_via: + return {'ok': False, 'error': 'connection_id is required'} if not isinstance(peer_name, str) or not re.match(r'^[A-Za-z0-9_.-]{1,64}$', peer_name): return {'ok': False, 'error': f'invalid peer_name {peer_name!r}'} @@ -372,11 +374,16 @@ class ConnectivityManager(BaseServiceManager): return {'ok': False, 'error': 'peer_registry not available'} try: - ok = self.peer_registry.set_peer_exit_via(peer_name, exit_type) + ok = self.peer_registry.set_peer_exit_via(peer_name, exit_via) except Exception as e: logger.error(f"set_peer_exit: registry update failed: {e}") return {'ok': False, 'error': str(e)} if not ok: + # Distinguish "no such peer" from "no such connection". + if self._peer_exists(peer_name): + return {'ok': False, 'error': + f'unknown connection {exit_via!r}; ' + f"must be a connection id or 'default'"} return {'ok': False, 'error': f'peer {peer_name!r} not found'} try: @@ -384,7 +391,23 @@ class ConnectivityManager(BaseServiceManager): except Exception as e: logger.warning(f"set_peer_exit: apply_routes failed (non-fatal): {e}") - return {'ok': True, 'peer': peer_name, 'exit_via': exit_type} + resolved = 'default' + try: + peer = self.peer_registry.get_peer(peer_name) + if peer: + resolved = peer.get('exit_via', 'default') + except Exception: + pass + return {'ok': True, 'peer': peer_name, 'exit_via': resolved} + + def _peer_exists(self, peer_name: str) -> bool: + """True when a peer with this name is registered.""" + if self.peer_registry is None: + return False + try: + return self.peer_registry.get_peer(peer_name) is not None + except Exception: + return False def upload_wireguard_ext(self, conf_text: str) -> Dict[str, Any]: """Validate and store an external WireGuard config.""" @@ -1121,17 +1144,26 @@ class ConnectivityManager(BaseServiceManager): def _connection_reference(self, conn_id: str) -> Optional[str]: """Return a human description if a peer/egress references this connection. - Phase 2 wires peers/egress to connection ids; until then nothing - references a connection, so this returns None. Kept as the single - choke-point so phase 2 only has to fill in the lookups here. + A peer references a connection through its exit_via field (a connection + id); a service references one through the egress_overrides map. Either + blocks deletion until the reference is detached. """ if self.peer_registry is not None: try: for peer in self.peer_registry.list_peers(): - if peer.get('connection_id') == conn_id: + if peer.get('exit_via') == conn_id: return f"peer {peer.get('peer')!r}" except Exception as e: logger.debug(f"_connection_reference (peers): {e}") + if self.config_manager is not None: + try: + overrides = self.config_manager.configs.get('egress_overrides') + if isinstance(overrides, dict): + for svc_id, cid in overrides.items(): + if cid == conn_id: + return f"service {svc_id!r}" + except Exception as e: + logger.debug(f"_connection_reference (egress): {e}") return None def list_connections(self) -> List[Dict[str, Any]]: @@ -1303,7 +1335,14 @@ class ConnectivityManager(BaseServiceManager): # ── Routing application ─────────────────────────────────────────────── def apply_routes(self) -> Dict[str, Any]: - """Idempotently rebuild all connectivity rules and policy routing.""" + """Idempotently rebuild all connectivity rules and policy routing. + + Connectivity v2: routing is driven by connection *instances*, not by + per-type constants. Each connection carries its own persisted mark, + table, iface and redirect_port; two instances of the same type route + through distinct tables/marks without collision. A peer's exit_via is + the id of the connection it egresses through. + """ rules_applied = 0 try: @@ -1319,18 +1358,23 @@ class ConnectivityManager(BaseServiceManager): except Exception as e: logger.warning(f"apply_routes: flush {table}/{chain} failed: {e}") - # Idempotent ip rule registration for each non-default exit - for exit_type in self.MARKS: - mark = self.MARKS[exit_type] - table = self.TABLES[exit_type] + connections = self._routing_connections() + + # Idempotent ip rule registration: one fwmark→table rule per instance. + for conn in connections: + mark, table = conn.get('mark'), conn.get('table') + if not isinstance(mark, int) or not isinstance(table, int): + continue try: self._remove_ip_rule(mark, table) self._add_ip_rule(mark, table) rules_applied += 1 except Exception as e: - logger.warning(f"apply_routes: ip rule {exit_type} failed: {e}") + logger.warning( + f"apply_routes: ip rule {conn.get('id')} failed: {e}") - # Per-peer marking + nat redirect (Tor only) + # Per-peer marking + nat redirect, resolved through each peer's + # connection instance. if self.peer_registry is not None: try: peers = self.peer_registry.list_peers() @@ -1338,45 +1382,82 @@ class ConnectivityManager(BaseServiceManager): logger.warning(f"apply_routes: list_peers failed: {e}") peers = [] + by_id = {c.get('id'): c for c in connections} for peer in peers: - exit_via = peer.get('exit_via', 'default') - if exit_via == 'default' or exit_via not in self.MARKS: + conn = self._resolve_peer_connection(peer, by_id) + if conn is None: continue src_ip = self._peer_source_ip(peer.get('peer', '')) if not src_ip: continue - mark = self.MARKS[exit_via] - try: - self._add_mark_rule(src_ip, mark) - rules_applied += 1 - except Exception as e: - logger.warning( - f"apply_routes: mark rule for {src_ip}/{exit_via}: {e}" - ) - - # Tor / sshuttle / proxy: redirect TCP to the local - # transparent-proxy port for that exit. - if exit_via in self.REDIRECT_PORTS: - try: - self._add_redirect(src_ip, self.REDIRECT_PORTS[exit_via]) - rules_applied += 1 - except Exception as e: - logger.warning( - f"apply_routes: {exit_via} redirect for {src_ip}: {e}" - ) + rules_applied += self._apply_connection_for_src(src_ip, conn) # Kill-switch: drop marked packets that would otherwise leak via the - # default route if the exit interface is down. - for exit_type, iface in self.IFACES.items(): - mark = self.MARKS[exit_type] + # default route if an iface-based exit interface is down. + for conn in connections: + iface = conn.get('iface') + mark = conn.get('mark') + if not iface or not isinstance(mark, int): + continue try: self._add_killswitch(mark, iface) rules_applied += 1 except Exception as e: - logger.warning(f"apply_routes: killswitch {exit_type}: {e}") + logger.warning( + f"apply_routes: killswitch {conn.get('id')}: {e}") return {'ok': True, 'rules_applied': rules_applied} + def _routing_connections(self) -> List[Dict[str, Any]]: + """Return the connection instances that drive routing (enabled only).""" + if self.config_manager is None: + return [] + try: + conns = self.config_manager.list_connections() + except Exception as e: + logger.warning(f"apply_routes: list_connections failed: {e}") + return [] + return [c for c in conns if c.get('enabled', True)] + + @staticmethod + def _resolve_peer_connection( + peer: Dict[str, Any], by_id: Dict[str, Dict[str, Any]], + ) -> Optional[Dict[str, Any]]: + """Resolve a peer's exit_via (a connection id) to its connection record.""" + exit_via = peer.get('exit_via', 'default') + if exit_via == 'default': + return None + return by_id.get(exit_via) + + def _apply_connection_for_src( + self, src_ip: str, conn: Dict[str, Any], + ) -> int: + """Mark + optionally REDIRECT traffic from src_ip via this connection. + + Returns the number of rules applied. iface-based connections only need + the fwmark (policy route + killswitch handle egress); redirect-style + connections additionally REDIRECT TCP to the instance's redirect_port. + """ + applied = 0 + mark = conn.get('mark') + if isinstance(mark, int): + try: + self._add_mark_rule(src_ip, mark) + applied += 1 + except Exception as e: + logger.warning( + f"apply_routes: mark rule for {src_ip}/{conn.get('id')}: {e}") + + redirect_port = conn.get('redirect_port') + if conn.get('type') in self.REDIRECT_TYPES and isinstance(redirect_port, int): + try: + self._add_redirect(src_ip, redirect_port) + applied += 1 + except Exception as e: + logger.warning( + f"apply_routes: redirect for {src_ip}/{conn.get('id')}: {e}") + return applied + # ── iptables / ip rule helpers ──────────────────────────────────────── def _wg_iptables(self, args: List[str], timeout: int = 10) -> subprocess.CompletedProcess: diff --git a/api/egress_manager.py b/api/egress_manager.py index 8ca8389..dc8f519 100644 --- a/api/egress_manager.py +++ b/api/egress_manager.py @@ -9,8 +9,13 @@ for install/remove lifecycle hooks. Rules live on the HOST in PIC_EGRESS chains in the mangle and nat tables. Container IPs are discovered via docker inspect using the -container_name from the service manifest. Marks are distinct from -ConnectivityManager to prevent rule collisions. +container_name from the service manifest. + +Connectivity v2: a service routes through a *connection instance* (by id), +sharing the same fwmark / routing table / redirect port as any peer that +egresses through the same connection. The (mark, table, redirect_port) for a +service are resolved from ConnectivityManager.get_connection(id) — EgressManager +no longer owns its own per-type MARKS/TABLES tables. """ import logging import subprocess @@ -19,34 +24,18 @@ from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) -EXIT_TYPES = ("default", "wireguard_ext", "openvpn", "tor", "sshuttle", "proxy") - -# fwmark values — must not collide with ConnectivityManager -# (0x10, 0x20, 0x30, 0x40, 0x50) -MARKS = {"wireguard_ext": 0x110, "openvpn": 0x120, "tor": 0x130, - "sshuttle": 0x140, "proxy": 0x150} - -# Policy routing table IDs -TABLES = {"wireguard_ext": 210, "openvpn": 220, "tor": 230, - "sshuttle": 240, "proxy": 250} - EGRESS_CHAIN = "PIC_EGRESS" -# Transparent proxy port used by Tor -_TOR_TRANS_PORT = 9040 - -# Local transparent-proxy ports for redirect-style exits (no exit iface): -# traffic is REDIRECTed to the listener of the corresponding exit container. -_REDIRECT_PORTS = {"tor": _TOR_TRANS_PORT, "sshuttle": 12300, "proxy": 12345} - class EgressManager: """Per-service egress enforcement via host iptables fwmark policy-routing.""" def __init__(self, config_manager, service_store_manager=None, + connectivity_manager=None, data_dir: str = "/app/data", config_dir: str = "/app/config"): self.config_manager = config_manager self.service_store_manager = service_store_manager + self.connectivity_manager = connectivity_manager self._data_dir = data_dir self._config_dir = config_dir @@ -60,9 +49,10 @@ class EgressManager: 2. clear_service first (ensures idempotency). 3. If the manifest has no egress block, skip silently. 4. Discover the container IP. - 5. Resolve the exit type (override > manifest default > 'default'). - 6. If exit is 'default', return early with no rules. - 7. Otherwise create chains, ensure ip rules, add mark rules. + 5. Resolve the connection id (override > manifest default > 'default'). + 6. If 'default', return early with no rules. + 7. Otherwise resolve the connection's (mark, table, redirect_port), + create chains, ensure ip rules, add mark/redirect rules. """ manifest = self._get_manifest(service_id) if manifest is None: @@ -79,36 +69,39 @@ class EgressManager: if not container_ip: return {'ok': False, 'error': 'container IP not discoverable'} - exit_via = self._resolve_exit(service_id, manifest) + connection_id = self._resolve_exit(service_id, manifest) - # Validate exit_via is a known, non-default value - if exit_via not in EXIT_TYPES: - return { - 'ok': False, - 'error': f'unknown exit_via {exit_via!r}; must be one of {EXIT_TYPES}', - } - - if exit_via == 'default': + if connection_id == 'default': return {'ok': True, 'exit_via': 'default'} - if exit_via not in MARKS: + conn = self._get_connection(connection_id) + if conn is None: return { 'ok': False, - 'error': f'unknown exit_via {exit_via!r}; must be one of {EXIT_TYPES}', + 'error': f'unknown connection {connection_id!r}', + } + + mark = conn.get('mark') + table = conn.get('table') + if not isinstance(mark, int) or not isinstance(table, int): + return { + 'ok': False, + 'error': f'connection {connection_id!r} has no routing resources', } try: self._ensure_chains() - self._ensure_host_ip_rules() - self._add_mark_rule(container_ip, MARKS[exit_via], service_id) - if exit_via in _REDIRECT_PORTS: - self._add_redirect(container_ip, _REDIRECT_PORTS[exit_via], - service_id) + self._ensure_host_ip_rule(mark, table) + self._add_mark_rule(container_ip, mark, service_id) + redirect_port = conn.get('redirect_port') + if isinstance(redirect_port, int): + self._add_redirect(container_ip, redirect_port, service_id) except Exception as exc: logger.error('apply_service(%s): %s', service_id, exc) return {'ok': False, 'error': str(exc)} - return {'ok': True, 'exit_via': exit_via, 'container_ip': container_ip} + return {'ok': True, 'exit_via': connection_id, + 'container_ip': container_ip} def clear_service(self, service_id: str) -> Dict[str, Any]: """Remove all PIC_EGRESS rules tagged for this service.""" @@ -129,10 +122,13 @@ class EgressManager: results[svc_id] = self.apply_service(svc_id) return {'ok': True, 'services': results} - def set_service_exit(self, service_id: str, exit_type: str) -> Dict[str, Any]: - """Persist a per-service egress override and immediately reapply rules. + def set_service_exit(self, service_id: str, connection_id: str) -> Dict[str, Any]: + """Persist a per-service egress override (by connection id) and reapply. - exit_type must appear in the manifest's egress.allowed list. + `connection_id` is a real connection id or 'default'. A legacy exit + *type* string is accepted as a one-release back-compat shim and resolved + to the single connection instance of that type. The resolved + connection's type must be in the manifest's egress.allowed list. """ manifest = self._get_manifest(service_id) if manifest is None: @@ -141,31 +137,91 @@ class EgressManager: if not self._has_egress(manifest): return {'ok': False, 'error': f'service {service_id!r} has no egress configuration'} + if connection_id == 'default': + overrides = self._get_egress_overrides() + overrides[service_id] = 'default' + self._set_egress_overrides(overrides) + return self.apply_service(service_id) + + resolved = self._resolve_connection_id(connection_id) + if resolved is None: + return { + 'ok': False, + 'error': f"unknown connection {connection_id!r}; " + f"must be a connection id or 'default'", + } + + conn = self._get_connection(resolved) egress = manifest.get('egress', {}) - allowed = egress.get('allowed', list(EXIT_TYPES)) - - if exit_type not in allowed: - return { - 'ok': False, - 'error': ( - f'exit_type {exit_type!r} is not in the allowed list ' - f'for {service_id}: {allowed}' - ), - } - - if exit_type not in EXIT_TYPES: - return { - 'ok': False, - 'error': f'unknown exit_type {exit_type!r}; must be one of {EXIT_TYPES}', - } + allowed = egress.get('allowed') + if isinstance(allowed, list) and conn is not None: + if conn.get('type') not in allowed: + return { + 'ok': False, + 'error': ( + f"connection type {conn.get('type')!r} is not in the " + f'allowed list for {service_id}: {allowed}' + ), + } # Persist the override so it survives restarts overrides = self._get_egress_overrides() - overrides[service_id] = exit_type + overrides[service_id] = resolved self._set_egress_overrides(overrides) return self.apply_service(service_id) + def _connections(self) -> List[dict]: + """Return the v2 connection records, or [] when unavailable.""" + if self.connectivity_manager is not None: + try: + conns = self.connectivity_manager.list_connections() + return conns if isinstance(conns, list) else [] + except Exception as exc: + logger.warning('egress: list_connections failed: %s', exc) + return [] + if self.config_manager is not None: + try: + conns = self.config_manager.list_connections() + return conns if isinstance(conns, list) else [] + except Exception as exc: + logger.warning('egress: list_connections failed: %s', exc) + return [] + + def _get_connection(self, connection_id: str) -> Optional[dict]: + """Resolve a connection record (with mark/table/redirect_port) by id.""" + if self.connectivity_manager is not None: + try: + return self.connectivity_manager.get_connection(connection_id) + except Exception as exc: + logger.warning('egress: get_connection failed: %s', exc) + return None + if self.config_manager is not None: + try: + return self.config_manager.get_connection(connection_id) + except Exception as exc: + logger.warning('egress: get_connection failed: %s', exc) + return None + + _LEGACY_EXIT_TYPES = ('wireguard_ext', 'openvpn', 'tor', 'sshuttle', 'proxy') + + def _resolve_connection_id(self, value: str) -> Optional[str]: + """Resolve a value to a valid connection id. + + Accepts a real connection id, or — as a back-compat shim — a legacy + type string resolved to the single instance of that type. Returns None + when nothing matches. + """ + conns = self._connections() + for c in conns: + if c.get('id') == value: + return value + if value in self._LEGACY_EXIT_TYPES: + matches = [c for c in conns if c.get('type') == value] + if len(matches) == 1: + return matches[0].get('id') + return None + def get_status(self) -> Dict[str, Any]: """Return egress status for every installed service that has egress config.""" installed = self.config_manager.get_installed_services() @@ -201,15 +257,26 @@ class EgressManager: return bool(manifest.get('has_egress', False) and manifest.get('egress')) def _resolve_exit(self, service_id: str, manifest: dict) -> str: - """Determine the effective exit for a service. + """Determine the effective connection id for a service. Priority: persisted override > manifest egress.default > 'default'. + Legacy type strings (from old overrides or a manifest default) are + resolved to the single connection instance of that type; if that can't + be resolved the service falls back to 'default'. """ overrides = self._get_egress_overrides() if service_id in overrides: - return overrides[service_id] - egress = manifest.get('egress') or {} - return egress.get('default', 'default') + value = overrides[service_id] + else: + egress = manifest.get('egress') or {} + value = egress.get('default', 'default') + + if value == 'default': + return 'default' + if value in self._LEGACY_EXIT_TYPES: + resolved = self._resolve_connection_id(value) + return resolved if resolved is not None else 'default' + return value def _discover_container_ip(self, container_name: str, retries: int = 5, delay: float = 0.2) -> Optional[str]: @@ -254,16 +321,18 @@ class EgressManager: ['-t', table, '-I', 'PREROUTING', '1', '-j', EGRESS_CHAIN] ) - def _ensure_host_ip_rules(self) -> None: - """Ensure `ip rule fwmark lookup ` exists for each exit.""" - for exit_type, mark in MARKS.items(): - table = TABLES[exit_type] - # Remove any existing duplicate rules first, then add once - for _ in range(8): - r = self._ip_rule(['del', 'fwmark', hex(mark), 'lookup', str(table)]) - if r.returncode != 0: - break - self._ip_rule(['add', 'fwmark', hex(mark), 'lookup', str(table)]) + def _ensure_host_ip_rule(self, mark: int, table: int) -> None: + """Ensure a single `ip rule fwmark lookup
` exists. + + Idempotent: drains any duplicate rules first, then adds exactly one. + The mark/table belong to the connection instance the service routes + through, so a peer and a service on the same connection share the rule. + """ + for _ in range(8): + r = self._ip_rule(['del', 'fwmark', hex(mark), 'lookup', str(table)]) + if r.returncode != 0: + break + self._ip_rule(['add', 'fwmark', hex(mark), 'lookup', str(table)]) def _add_mark_rule(self, service_ip: str, mark: int, service_id: str) -> None: """Mark outbound packets from the service container with fwmark.""" @@ -283,10 +352,6 @@ class EgressManager: '-m', 'comment', '--comment', self._tag(service_id), ]) - def _add_tor_redirect(self, service_ip: str, service_id: str) -> None: - """Redirect the service container's TCP traffic to the local Tor TransPort.""" - self._add_redirect(service_ip, _TOR_TRANS_PORT, service_id) - def _clear_egress_rules(self, service_id: str) -> None: """Remove all rules tagged pic-egr- from mangle and nat.""" import re as _re diff --git a/api/managers.py b/api/managers.py index 57c7951..1e312a1 100644 --- a/api/managers.py +++ b/api/managers.py @@ -53,7 +53,8 @@ service_registry = ServiceRegistry(config_manager=config_manager) network_manager = NetworkManager(data_dir=DATA_DIR, config_dir=CONFIG_DIR, service_registry=service_registry) wireguard_manager = WireGuardManager(data_dir=DATA_DIR, config_dir=CONFIG_DIR) -peer_registry = PeerRegistry(data_dir=DATA_DIR, config_dir=CONFIG_DIR) +peer_registry = PeerRegistry(data_dir=DATA_DIR, config_dir=CONFIG_DIR, + config_manager=config_manager) email_manager = EmailManager(data_dir=DATA_DIR, config_dir=CONFIG_DIR, service_bus=service_bus) calendar_manager = CalendarManager(data_dir=DATA_DIR, config_dir=CONFIG_DIR) file_manager = FileManager(data_dir=DATA_DIR, config_dir=CONFIG_DIR) @@ -102,6 +103,7 @@ from egress_manager import EgressManager egress_manager = EgressManager( config_manager=config_manager, service_store_manager=service_store_manager, + connectivity_manager=connectivity_manager, data_dir=DATA_DIR, config_dir=CONFIG_DIR, ) diff --git a/api/peer_registry.py b/api/peer_registry.py index c62ef0c..914c70d 100644 --- a/api/peer_registry.py +++ b/api/peer_registry.py @@ -17,11 +17,17 @@ logger = logging.getLogger(__name__) class PeerRegistry(BaseServiceManager): """Manages peer registration and management""" - def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config'): + def __init__(self, data_dir: str = '/app/data', config_dir: str = '/app/config', + config_manager=None): super().__init__('peer_registry', data_dir, config_dir) self.lock = RLock() self.peers = [] self.peers_file = os.path.join(data_dir, 'peers.json') + # config_manager is used to resolve/validate connection ids for the + # per-peer exit (exit_via). It may be wired after construction (the + # singletons in managers.py are built in dependency order), so the + # exit_via→connection-id migration also runs lazily, idempotently. + self.config_manager = config_manager self._load_peers() def get_status(self) -> Dict[str, Any]: @@ -205,6 +211,11 @@ class PeerRegistry(BaseServiceManager): changed = True if changed: self._save_peers() + # Phase 2 (connectivity v2): exit_via is now a connection id (or + # 'default'). Rewrite any legacy per-type exit_via to the id of + # the single migrated connection instance of that type. Runs + # lazily if config_manager is not yet wired. + self._migrate_exit_via_to_connection_id() else: self.peers = [] self.logger.info("No peers file found, starting with empty registry") @@ -350,26 +361,101 @@ class PeerRegistry(BaseServiceManager): return dict(peer) raise ValueError(f"Peer '{peer_name}' not found") - # Phase 5: extended connectivity per-peer egress exit - VALID_EXIT_VIA = ('default', 'wireguard_ext', 'openvpn', 'tor', - 'sshuttle', 'proxy') + # Connectivity v2: legacy per-type exit values. A peer's exit_via is now a + # connection id (or 'default'); these strings are accepted only as a + # one-release back-compat shim — resolved to the single migrated instance + # of that type via config_manager.list_connections(). + _LEGACY_EXIT_TYPES = ('wireguard_ext', 'openvpn', 'tor', 'sshuttle', 'proxy') + + def _connections(self) -> List[Dict[str, Any]]: + """Return the v2 connection records, or [] when unavailable.""" + if self.config_manager is None: + return [] + try: + conns = self.config_manager.list_connections() + except Exception as e: + self.logger.warning(f"peer_registry: list_connections failed: {e}") + return [] + return conns if isinstance(conns, list) else [] + + def _resolve_exit_via(self, value: str) -> Optional[str]: + """Resolve an exit_via value to a valid connection id or 'default'. + + Accepts 'default', a real connection id, or — as a back-compat shim — + a legacy type string (resolved to the single instance of that type). + Returns None when the value cannot be resolved to anything valid. + """ + if value == 'default': + return 'default' + conns = self._connections() + for c in conns: + if c.get('id') == value: + return value + if value in self._LEGACY_EXIT_TYPES: + matches = [c for c in conns if c.get('type') == value] + if len(matches) == 1: + return matches[0].get('id') + return None + + def _migrate_exit_via_to_connection_id(self) -> bool: + """Rewrite legacy per-type exit_via values to migrated connection ids. + + Idempotent: ids and 'default' are left untouched. Legacy type strings + are mapped to the single instance of that type; if no instance exists + the peer falls back to 'default'. Returns True if anything changed. + Runs only when config_manager (and its v2 connections) are available. + """ + if self.config_manager is None: + return False + conns = self._connections() + valid_ids = {c.get('id') for c in conns} + by_type: Dict[str, List[str]] = {} + for c in conns: + by_type.setdefault(c.get('type'), []).append(c.get('id')) + + changed = False + with self.lock: + for peer in self.peers: + exit_via = peer.get('exit_via', 'default') + if exit_via == 'default' or exit_via in valid_ids: + continue + new_value = 'default' + if exit_via in self._LEGACY_EXIT_TYPES: + ids = by_type.get(exit_via, []) + if len(ids) == 1: + new_value = ids[0] + peer['exit_via'] = new_value + changed = True + self.logger.info( + f"peer_registry: migrated exit_via {exit_via!r} → " + f"{new_value!r} for {peer.get('peer')!r}" + ) + if changed: + self._save_peers() + return changed def set_peer_exit_via(self, peer_name: str, exit_type: str) -> bool: - """Set the per-peer egress exit type. Returns True if updated, False - if the peer is not found (logged as warning, no exception).""" - if exit_type not in self.VALID_EXIT_VIA: + """Set the per-peer egress connection id. Returns True if updated, False + if the peer is not found or the id is invalid (logged, no exception). + + `exit_type` must be a real connection id or 'default'. A legacy type + string is accepted as a back-compat shim and resolved to the single + instance of that type. + """ + resolved = self._resolve_exit_via(exit_type) + if resolved is None: self.logger.warning( - f"set_peer_exit_via: invalid exit_type {exit_type!r}" + f"set_peer_exit_via: invalid connection id {exit_type!r}" ) return False with self.lock: for peer in self.peers: if peer.get('peer') == peer_name: - peer['exit_via'] = exit_type + peer['exit_via'] = resolved peer['updated_at'] = datetime.utcnow().isoformat() self._save_peers() self.logger.info( - f"Set exit_via for {peer_name}: {exit_type!r}" + f"Set exit_via for {peer_name}: {resolved!r}" ) return True self.logger.warning( diff --git a/tests/test_connectivity_connections.py b/tests/test_connectivity_connections.py index 4647d6c..126968d 100644 --- a/tests/test_connectivity_connections.py +++ b/tests/test_connectivity_connections.py @@ -276,7 +276,7 @@ class TestDeleteConnection(_Base): res = self.mgr.create_connection('proxy', 'ref', _proxy_cfg()) cid = res['connection']['id'] self.peer_registry.list_peers.return_value = [ - {'peer': 'alice', 'connection_id': cid}] + {'peer': 'alice', 'exit_via': cid}] out = self.mgr.delete_connection(cid) self.assertFalse(out['ok']) self.assertIn('in use', out['error']) diff --git a/tests/test_connectivity_manager.py b/tests/test_connectivity_manager.py index a3fb21f..de763a5 100644 --- a/tests/test_connectivity_manager.py +++ b/tests/test_connectivity_manager.py @@ -42,6 +42,7 @@ def _make_manager(tmp_dir=None, peer_registry=_SENTINEL, config_manager=None): 'cell_name': 'test', 'ip_range': '172.20.0.0/16', } + config_manager.list_connections.return_value = [] if peer_registry is _SENTINEL: peer_registry = MagicMock() @@ -530,49 +531,58 @@ class TestSetPeerExit(unittest.TestCase): peer_registry.list_peers.return_value = [] return _make_manager(tmp_dir=self.tmp, peer_registry=peer_registry) - def test_valid_exit_type_returns_ok_true(self): + def test_valid_connection_id_returns_ok_true(self): mgr = self._mgr() with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() - result = mgr.set_peer_exit('alice', 'wireguard_ext') + result = mgr.set_peer_exit('alice', 'conn_abcd') self.assertTrue(result['ok']) - def test_valid_exit_type_default_returns_ok_true(self): + def test_default_returns_ok_true(self): mgr = self._mgr() with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() result = mgr.set_peer_exit('alice', 'default') self.assertTrue(result['ok']) - def test_invalid_exit_type_returns_ok_false(self): + def test_empty_connection_id_returns_ok_false(self): mgr = self._mgr() - result = mgr.set_peer_exit('alice', 'shadowsocks') + result = mgr.set_peer_exit('alice', '') self.assertFalse(result['ok']) self.assertIn('error', result) - def test_invalid_exit_type_error_mentions_type(self): - mgr = self._mgr() - result = mgr.set_peer_exit('alice', 'bad_type') - self.assertIn('bad_type', result['error']) + def test_unknown_connection_for_existing_peer_returns_ok_false(self): + """When the peer exists but the connection id is rejected by the + registry, set_peer_exit reports an unknown-connection error.""" + pr = MagicMock() + pr.set_peer_exit_via.return_value = False + pr.get_peer.return_value = {'peer': 'alice', 'ip': '10.0.0.5'} + pr.list_peers.return_value = [] + mgr = self._mgr(peer_registry=pr) + result = mgr.set_peer_exit('alice', 'conn_ghost') + self.assertFalse(result['ok']) + self.assertIn('unknown connection', result['error']) def test_calls_peer_registry_set_peer_exit_via_with_correct_args(self): pr = MagicMock() pr.set_peer_exit_via.return_value = True pr.list_peers.return_value = [] + pr.get_peer.return_value = {'peer': 'bob', 'exit_via': 'conn_xyz'} mgr = self._mgr(peer_registry=pr) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() - mgr.set_peer_exit('bob', 'openvpn') - pr.set_peer_exit_via.assert_called_once_with('bob', 'openvpn') + mgr.set_peer_exit('bob', 'conn_xyz') + pr.set_peer_exit_via.assert_called_once_with('bob', 'conn_xyz') def test_peer_not_found_in_registry_returns_ok_false(self): pr = MagicMock() pr.set_peer_exit_via.return_value = False # peer not found + pr.get_peer.return_value = None # peer truly absent pr.list_peers.return_value = [] mgr = self._mgr(peer_registry=pr) - result = mgr.set_peer_exit('unknown-peer', 'tor') + result = mgr.set_peer_exit('unknown-peer', 'conn_tor') self.assertFalse(result['ok']) - self.assertIn('error', result) + self.assertIn('not found', result['error']) def test_invalid_peer_name_returns_ok_false(self): mgr = self._mgr() @@ -706,14 +716,22 @@ class TestApplyRoutes(unittest.TestCase): result = mgr.apply_routes() self.assertIsInstance(result, dict) - def test_peer_with_wireguard_ext_exit_generates_mark_rule(self): - """Peers with a non-default exit should trigger _add_mark_rule calls.""" + def _cm_with_connections(self, connections): + cm = MagicMock() + cm.get_identity.return_value = {'cell_name': 't', 'ip_range': '172.20.0.0/16'} + cm.list_connections.return_value = connections + return cm + + def test_peer_with_connection_exit_generates_mark_rule(self): + """A peer whose exit_via is a connection id gets that connection's mark.""" + conns = [{'id': 'conn_wg', 'type': 'wireguard_ext', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': 'wgext_x', + 'redirect_port': None}] pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'alice', 'exit_via': 'wireguard_ext'}, - ] + pr.list_peers.return_value = [{'peer': 'alice', 'exit_via': 'conn_wg'}] pr.get_peer.return_value = {'peer': 'alice', 'ip': '172.20.0.50/32'} - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm_with_connections(conns)) with patch.object(mgr, '_add_mark_rule') as mock_mark, \ patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() @@ -721,14 +739,14 @@ class TestApplyRoutes(unittest.TestCase): mock_mark.assert_called() call_args = mock_mark.call_args[0] self.assertEqual(call_args[0], '172.20.0.50') # IP without CIDR + self.assertEqual(call_args[1], 0x1000) # the connection's mark def test_peer_with_default_exit_skips_mark_rule(self): """Peers on default exit must not generate mark rules.""" pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'bob', 'exit_via': 'default'}, - ] - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + pr.list_peers.return_value = [{'peer': 'bob', 'exit_via': 'default'}] + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm_with_connections([])) with patch.object(mgr, '_add_mark_rule') as mock_mark, \ patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() @@ -744,6 +762,171 @@ class TestApplyRoutes(unittest.TestCase): self.assertIsInstance(result['rules_applied'], int) +# --------------------------------------------------------------------------- +# apply_routes — instance-aware routing (connectivity v2) +# --------------------------------------------------------------------------- + +class TestApplyRoutesInstances(unittest.TestCase): + """apply_routes must drive routing from connection instances, so two + instances of the same type route through distinct tables/marks without + collision, and each peer gets its own connection's mark.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _cm(self, connections): + cm = MagicMock() + cm.get_identity.return_value = {'cell_name': 't', 'ip_range': '172.20.0.0/16'} + cm.list_connections.return_value = connections + return cm + + @staticmethod + def _docker_args(call): + """Strip the `docker exec ` prefix from a call.""" + args = call.args[0] + # args == ['docker', 'exec', CONTAINER, 'ip'|'iptables', ] + return args[4:] + + def test_two_wireguard_ext_instances_distinct_tables_no_collision(self): + conns = [ + {'id': 'conn_a', 'type': 'wireguard_ext', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': 'wgext_a', 'redirect_port': None}, + {'id': 'conn_b', 'type': 'wireguard_ext', 'enabled': True, + 'mark': 0x1010, 'table': 1001, 'iface': 'wgext_b', 'redirect_port': None}, + ] + pr = MagicMock() + pr.list_peers.return_value = [] + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm(conns)) + + with patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + + rule_adds = [] + for c in mock_sp.run.call_args_list: + args = self._docker_args(c) + if args[:3] == ['rule', 'add', 'fwmark']: + rule_adds.append(args) + + # One ip rule per instance, each pointing its own mark at its own table. + pairs = {(a[3], a[5]) for a in rule_adds} # (fwmark_hex, table) + self.assertIn(('0x1000', '1000'), pairs) + self.assertIn(('0x1010', '1001'), pairs) + # Marks and tables must be distinct — no collision. + marks = [a[3] for a in rule_adds] + tables = [a[5] for a in rule_adds] + self.assertEqual(len(set(marks)), len(marks)) + self.assertEqual(len(set(tables)), len(tables)) + + def test_two_peers_on_two_instances_get_different_marks(self): + conns = [ + {'id': 'conn_a', 'type': 'wireguard_ext', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': 'wgext_a', 'redirect_port': None}, + {'id': 'conn_b', 'type': 'wireguard_ext', 'enabled': True, + 'mark': 0x1010, 'table': 1001, 'iface': 'wgext_b', 'redirect_port': None}, + ] + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'alice', 'exit_via': 'conn_a'}, + {'peer': 'bob', 'exit_via': 'conn_b'}, + ] + pr.get_peer.side_effect = lambda n: { + 'alice': {'peer': 'alice', 'ip': '172.20.0.50/32'}, + 'bob': {'peer': 'bob', 'ip': '172.20.0.51/32'}, + }[n] + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm(conns)) + + marks_by_ip = {} + + def capture(src_ip, mark): + marks_by_ip[src_ip] = mark + + with patch.object(mgr, '_add_mark_rule', side_effect=capture), \ + patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + + self.assertEqual(marks_by_ip['172.20.0.50'], 0x1000) + self.assertEqual(marks_by_ip['172.20.0.51'], 0x1010) + self.assertNotEqual(marks_by_ip['172.20.0.50'], marks_by_ip['172.20.0.51']) + + def test_two_redirect_instances_distinct_ports(self): + """Two redirect-style instances REDIRECT their peers to distinct ports.""" + conns = [ + {'id': 'conn_p1', 'type': 'proxy', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': None, 'redirect_port': 9100}, + {'id': 'conn_p2', 'type': 'proxy', 'enabled': True, + 'mark': 0x1010, 'table': 1001, 'iface': None, 'redirect_port': 9101}, + ] + pr = MagicMock() + pr.list_peers.return_value = [ + {'peer': 'alice', 'exit_via': 'conn_p1'}, + {'peer': 'bob', 'exit_via': 'conn_p2'}, + ] + pr.get_peer.side_effect = lambda n: { + 'alice': {'peer': 'alice', 'ip': '172.20.0.50/32'}, + 'bob': {'peer': 'bob', 'ip': '172.20.0.51/32'}, + }[n] + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm(conns)) + + ports_by_ip = {} + + def capture(src_ip, port): + ports_by_ip[src_ip] = port + + with patch.object(mgr, '_add_redirect', side_effect=capture), \ + patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + + self.assertEqual(ports_by_ip['172.20.0.50'], 9100) + self.assertEqual(ports_by_ip['172.20.0.51'], 9101) + + def test_single_instance_equivalence(self): + """An already-migrated cell with one instance per type yields the same + effective rules as before: one ip rule (mark→table), one mark per + peer, killswitch on the iface-based instance.""" + conns = [ + {'id': 'conn_wg', 'type': 'wireguard_ext', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': 'wgext_x', 'redirect_port': None}, + ] + pr = MagicMock() + pr.list_peers.return_value = [{'peer': 'alice', 'exit_via': 'conn_wg'}] + pr.get_peer.return_value = {'peer': 'alice', 'ip': '172.20.0.50/32'} + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm(conns)) + + with patch.object(mgr, '_add_killswitch') as mock_ks, \ + patch.object(mgr, '_add_mark_rule') as mock_mark, \ + patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + + mock_mark.assert_called_once_with('172.20.0.50', 0x1000) + mock_ks.assert_called_once_with(0x1000, 'wgext_x') + + def test_disabled_instance_is_skipped(self): + conns = [ + {'id': 'conn_off', 'type': 'wireguard_ext', 'enabled': False, + 'mark': 0x1000, 'table': 1000, 'iface': 'wgext_x', 'redirect_port': None}, + ] + pr = MagicMock() + pr.list_peers.return_value = [] + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm(conns)) + with patch.object(mgr, '_add_killswitch') as mock_ks, \ + patch.object(cm_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = _mock_subprocess_ok() + mgr.apply_routes() + mock_ks.assert_not_called() + + # --------------------------------------------------------------------------- # _exit_status — status string + store-service bridge # --------------------------------------------------------------------------- diff --git a/tests/test_connectivity_proxy.py b/tests/test_connectivity_proxy.py index 416f22a..1a33a3c 100644 --- a/tests/test_connectivity_proxy.py +++ b/tests/test_connectivity_proxy.py @@ -37,6 +37,7 @@ def _make_manager(tmp_dir=None, peer_registry=_SENTINEL, config_manager=None): 'exits': {}, 'peer_exit_map': {}, } config_manager.get_installed_services.return_value = {} + config_manager.list_connections.return_value = [] if peer_registry is _SENTINEL: peer_registry = MagicMock() @@ -265,13 +266,26 @@ class TestApplyRoutesProxy(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmp, ignore_errors=True) - def test_proxy_peer_gets_redirect_to_12345(self): + @staticmethod + def _proxy_conn(redirect_port=9100, mark=0x1000, table=1000): + return {'id': 'conn_proxy', 'type': 'proxy', 'enabled': True, + 'mark': mark, 'table': table, 'iface': None, + 'redirect_port': redirect_port} + + def _cm(self, connections): + cm = MagicMock() + cm.get_identity.return_value = {'cell_name': 't', 'ip_range': '172.20.0.0/16'} + cm.list_connections.return_value = connections + cm.get_installed_services.return_value = {} + return cm + + def test_proxy_peer_gets_redirect_to_instance_port(self): + conn = self._proxy_conn(redirect_port=9100) pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'bob', 'exit_via': 'proxy'}, - ] + pr.list_peers.return_value = [{'peer': 'bob', 'exit_via': 'conn_proxy'}] pr.get_peer.return_value = {'peer': 'bob', 'ip': '172.20.0.60/32'} - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() mgr.apply_routes() @@ -281,16 +295,16 @@ class TestApplyRoutesProxy(unittest.TestCase): ] self.assertEqual(len(redirect_calls), 1) args = redirect_calls[0].args[0] - self.assertEqual(args[args.index('--to-ports') + 1], '12345') + self.assertEqual(args[args.index('--to-ports') + 1], '9100') self.assertIn('172.20.0.60', args) - def test_proxy_peer_gets_mark_0x50(self): + def test_proxy_peer_gets_instance_mark(self): + conn = self._proxy_conn(mark=0x1020) pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'bob', 'exit_via': 'proxy'}, - ] + pr.list_peers.return_value = [{'peer': 'bob', 'exit_via': 'conn_proxy'}] pr.get_peer.return_value = {'peer': 'bob', 'ip': '172.20.0.60/32'} - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() mgr.apply_routes() @@ -300,29 +314,32 @@ class TestApplyRoutesProxy(unittest.TestCase): ] self.assertEqual(len(mark_calls), 1) args = mark_calls[0].args[0] - self.assertEqual(args[args.index('--set-mark') + 1], hex(0x50)) + self.assertEqual(args[args.index('--set-mark') + 1], hex(0x1020)) - def test_ip_rule_added_for_proxy_table_150(self): - mgr = _make_manager(tmp_dir=self.tmp) + def test_ip_rule_added_for_instance_table(self): + conn = self._proxy_conn(mark=0x1030, table=1234) + mgr = _make_manager(tmp_dir=self.tmp, config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') mgr.apply_routes() rule_adds = [ c for c in mock_sp.run.call_args_list if 'rule' in c.args[0] and 'add' in c.args[0] - and hex(0x50) in c.args[0] + and hex(0x1030) in c.args[0] ] self.assertEqual(len(rule_adds), 1) - self.assertIn('150', rule_adds[0].args[0]) + self.assertIn('1234', rule_adds[0].args[0]) - def test_tor_redirect_still_uses_9040(self): - """Regression: tor redirect must be unaffected by the new exits.""" + def test_tor_redirect_uses_instance_port(self): + """A tor connection instance REDIRECTs to its own allocated port.""" + conn = {'id': 'conn_tor', 'type': 'tor', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': None, + 'redirect_port': 9100} pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'carol', 'exit_via': 'tor'}, - ] + pr.list_peers.return_value = [{'peer': 'carol', 'exit_via': 'conn_tor'}] pr.get_peer.return_value = {'peer': 'carol', 'ip': '172.20.0.70/32'} - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() mgr.apply_routes() @@ -332,7 +349,7 @@ class TestApplyRoutesProxy(unittest.TestCase): ] self.assertEqual(len(redirect_calls), 1) args = redirect_calls[0].args[0] - self.assertEqual(args[args.index('--to-ports') + 1], '9040') + self.assertEqual(args[args.index('--to-ports') + 1], '9100') # --------------------------------------------------------------------------- @@ -340,39 +357,32 @@ class TestApplyRoutesProxy(unittest.TestCase): # --------------------------------------------------------------------------- class TestEgressManagerMirror(unittest.TestCase): + """Egress now resolves a service's (mark, table, redirect_port) from the + connection instance it routes through — no per-type MARKS/TABLES tables.""" - def test_exit_types_include_sshuttle_and_proxy(self): - self.assertIn('sshuttle', em_module.EXIT_TYPES) - self.assertIn('proxy', em_module.EXIT_TYPES) - - def test_marks_do_not_collide_with_connectivity(self): - self.assertEqual(em_module.MARKS['sshuttle'], 0x140) - self.assertEqual(em_module.MARKS['proxy'], 0x150) - self.assertNotIn(em_module.MARKS['sshuttle'], - ConnectivityManager.MARKS.values()) - self.assertNotIn(em_module.MARKS['proxy'], - ConnectivityManager.MARKS.values()) - - def test_tables(self): - self.assertEqual(em_module.TABLES['sshuttle'], 240) - self.assertEqual(em_module.TABLES['proxy'], 250) - - def _make_egress(self, exit_via): + def _make_egress(self, connection): config_manager = MagicMock() manifest = { 'id': 'svc', 'container_name': 'cell-svc', 'has_egress': True, - 'egress': {'default': exit_via, 'allowed': list(em_module.EXIT_TYPES)}, + 'egress': {'default': connection['id'], + 'allowed': [connection['type']]}, } config_manager.get_installed_services.return_value = { 'svc': {'manifest': manifest}, } config_manager.configs = {'egress_overrides': {}} + config_manager.list_connections.return_value = [connection] + config_manager.get_connection.side_effect = \ + lambda cid: connection if cid == connection['id'] else None return em_module.EgressManager(config_manager=config_manager) - def test_apply_service_sshuttle_redirects_to_12300(self): - em = self._make_egress('sshuttle') + def test_apply_service_sshuttle_redirects_to_instance_port(self): + conn = {'id': 'conn_ssh', 'type': 'sshuttle', 'enabled': True, + 'mark': 0x1000, 'table': 1000, 'iface': None, + 'redirect_port': 9100} + em = self._make_egress(conn) with patch.object(em_module, 'subprocess') as mock_sp: mock_sp.run.return_value = MagicMock( returncode=0, stdout='172.21.0.5', stderr='') @@ -384,10 +394,13 @@ class TestEgressManagerMirror(unittest.TestCase): ] self.assertEqual(len(redirect_calls), 1) args = redirect_calls[0].args[0] - self.assertEqual(args[args.index('--to-ports') + 1], '12300') + self.assertEqual(args[args.index('--to-ports') + 1], '9100') - def test_apply_service_proxy_redirects_to_12345(self): - em = self._make_egress('proxy') + def test_apply_service_proxy_redirects_to_instance_port(self): + conn = {'id': 'conn_px', 'type': 'proxy', 'enabled': True, + 'mark': 0x1010, 'table': 1001, 'iface': None, + 'redirect_port': 9101} + em = self._make_egress(conn) with patch.object(em_module, 'subprocess') as mock_sp: mock_sp.run.return_value = MagicMock( returncode=0, stdout='172.21.0.5', stderr='') @@ -399,7 +412,24 @@ class TestEgressManagerMirror(unittest.TestCase): ] self.assertEqual(len(redirect_calls), 1) args = redirect_calls[0].args[0] - self.assertEqual(args[args.index('--to-ports') + 1], '12345') + self.assertEqual(args[args.index('--to-ports') + 1], '9101') + + def test_apply_service_uses_connection_mark(self): + conn = {'id': 'conn_px', 'type': 'proxy', 'enabled': True, + 'mark': 0x1010, 'table': 1001, 'iface': None, + 'redirect_port': 9101} + em = self._make_egress(conn) + with patch.object(em_module, 'subprocess') as mock_sp: + mock_sp.run.return_value = MagicMock( + returncode=0, stdout='172.21.0.5', stderr='') + em.apply_service('svc') + mark_calls = [ + c for c in mock_sp.run.call_args_list + if 'MARK' in c.args[0] and '--set-mark' in c.args[0] + ] + self.assertGreater(len(mark_calls), 0) + args = mark_calls[0].args[0] + self.assertEqual(args[args.index('--set-mark') + 1], hex(0x1010)) # --------------------------------------------------------------------------- diff --git a/tests/test_connectivity_sshuttle.py b/tests/test_connectivity_sshuttle.py index ba2e495..2498fd2 100644 --- a/tests/test_connectivity_sshuttle.py +++ b/tests/test_connectivity_sshuttle.py @@ -47,6 +47,7 @@ def _make_manager(tmp_dir=None, peer_registry=_SENTINEL, config_manager=None, 'exits': {}, 'peer_exit_map': {}, } config_manager.get_installed_services.return_value = {} + config_manager.list_connections.return_value = [] if peer_registry is _SENTINEL: peer_registry = MagicMock() @@ -344,13 +345,26 @@ class TestApplyRoutesSshuttle(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmp, ignore_errors=True) - def test_sshuttle_peer_gets_redirect_to_12300(self): + @staticmethod + def _ssh_conn(mark=0x1000, table=1000, redirect_port=9100): + return {'id': 'conn_ssh', 'type': 'sshuttle', 'enabled': True, + 'mark': mark, 'table': table, 'iface': None, + 'redirect_port': redirect_port} + + def _cm(self, connections): + cm = MagicMock() + cm.get_identity.return_value = {'cell_name': 't', 'ip_range': '172.20.0.0/16'} + cm.list_connections.return_value = connections + cm.get_installed_services.return_value = {} + return cm + + def test_sshuttle_peer_gets_redirect_to_instance_port(self): + conn = self._ssh_conn(redirect_port=9100) pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'alice', 'exit_via': 'sshuttle'}, - ] + pr.list_peers.return_value = [{'peer': 'alice', 'exit_via': 'conn_ssh'}] pr.get_peer.return_value = {'peer': 'alice', 'ip': '172.20.0.50/32'} - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() mgr.apply_routes() @@ -361,16 +375,16 @@ class TestApplyRoutesSshuttle(unittest.TestCase): self.assertEqual(len(redirect_calls), 1) args = redirect_calls[0].args[0] self.assertIn('--to-ports', args) - self.assertEqual(args[args.index('--to-ports') + 1], '12300') + self.assertEqual(args[args.index('--to-ports') + 1], '9100') self.assertIn('172.20.0.50', args) - def test_sshuttle_peer_gets_mark_0x40(self): + def test_sshuttle_peer_gets_instance_mark(self): + conn = self._ssh_conn(mark=0x1040) pr = MagicMock() - pr.list_peers.return_value = [ - {'peer': 'alice', 'exit_via': 'sshuttle'}, - ] + pr.list_peers.return_value = [{'peer': 'alice', 'exit_via': 'conn_ssh'}] pr.get_peer.return_value = {'peer': 'alice', 'ip': '172.20.0.50/32'} - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, + config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() mgr.apply_routes() @@ -380,20 +394,21 @@ class TestApplyRoutesSshuttle(unittest.TestCase): ] self.assertEqual(len(mark_calls), 1) args = mark_calls[0].args[0] - self.assertEqual(args[args.index('--set-mark') + 1], hex(0x40)) + self.assertEqual(args[args.index('--set-mark') + 1], hex(0x1040)) - def test_ip_rule_added_for_sshuttle_table_140(self): - mgr = _make_manager(tmp_dir=self.tmp) + def test_ip_rule_added_for_instance_table(self): + conn = self._ssh_conn(mark=0x1040, table=1399) + mgr = _make_manager(tmp_dir=self.tmp, config_manager=self._cm([conn])) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = MagicMock(returncode=1, stdout='', stderr='') mgr.apply_routes() rule_adds = [ c for c in mock_sp.run.call_args_list if 'rule' in c.args[0] and 'add' in c.args[0] - and hex(0x40) in c.args[0] + and hex(0x1040) in c.args[0] ] self.assertEqual(len(rule_adds), 1) - self.assertIn('140', rule_adds[0].args[0]) + self.assertIn('1399', rule_adds[0].args[0]) def test_no_killswitch_for_sshuttle(self): """sshuttle has no exit iface — _add_killswitch must skip it.""" @@ -481,19 +496,37 @@ class TestSetPeerExitSshuttle(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmp, ignore_errors=True) - def test_sshuttle_is_a_valid_exit_type(self): + def test_legacy_sshuttle_type_resolves_to_instance(self): + """Back-compat shim: setting exit to the legacy 'sshuttle' type resolves + to the single sshuttle connection instance.""" + conn = {'id': 'conn_ssh', 'type': 'sshuttle'} + cm = MagicMock() + cm.get_identity.return_value = {'ip_range': '172.20.0.0/16'} + cm.list_connections.return_value = [conn] + cm.get_installed_services.return_value = {} pr = MagicMock() - pr.set_peer_exit_via.return_value = True pr.list_peers.return_value = [] - mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr) + pr.get_peer.return_value = {'peer': 'alice', 'exit_via': 'conn_ssh'} + + def _set(name, value): + # mimic the real registry shim resolution + return value in ('default', 'conn_ssh', 'sshuttle') + pr.set_peer_exit_via.side_effect = _set + mgr = _make_manager(tmp_dir=self.tmp, peer_registry=pr, config_manager=cm) with patch.object(cm_module, 'subprocess') as mock_sp: mock_sp.run.return_value = _mock_subprocess_ok() result = mgr.set_peer_exit('alice', 'sshuttle') self.assertTrue(result['ok']) - def test_peer_registry_accepts_sshuttle(self): + def test_peer_registry_accepts_sshuttle_legacy_type(self): + """The peer registry resolves a legacy 'sshuttle' type to its instance id.""" from peer_registry import PeerRegistry - self.assertIn('sshuttle', PeerRegistry.VALID_EXIT_VIA) + cm = MagicMock() + cm.list_connections.return_value = [{'id': 'conn_ssh', 'type': 'sshuttle'}] + reg = PeerRegistry(data_dir=self.tmp, config_dir=self.tmp, config_manager=cm) + reg.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + self.assertTrue(reg.set_peer_exit_via('alice', 'sshuttle')) + self.assertEqual(reg.get_peer('alice')['exit_via'], 'conn_ssh') # --------------------------------------------------------------------------- diff --git a/tests/test_egress_manager.py b/tests/test_egress_manager.py index 3139725..a25ef40 100644 --- a/tests/test_egress_manager.py +++ b/tests/test_egress_manager.py @@ -1,49 +1,85 @@ """ Tests for EgressManager — per-service egress enforcement via host iptables. +Connectivity v2: a service routes through a connection *instance* (by id), +sharing the connection's fwmark / routing table / redirect port. The egress +override map is service_id → connection_id, and (mark, table, redirect_port) +are resolved from ConnectivityManager.get_connection(id). EgressManager no +longer owns its own per-type MARKS/TABLES. + All subprocess calls (iptables, iptables-save, iptables-restore, ip rule, -docker inspect) and config_manager state are mocked so these tests run -without any live infrastructure or root privileges. +docker inspect) and manager state are mocked so these tests run without any +live infrastructure or root privileges. """ import os import sys import unittest -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'api')) import egress_manager as em_module -from egress_manager import EgressManager, MARKS, TABLES, EXIT_TYPES, EGRESS_CHAIN +from egress_manager import EgressManager, EGRESS_CHAIN # --------------------------------------------------------------------------- -# Helpers +# Connection fixtures — mirror the v2 allocator output. # --------------------------------------------------------------------------- -def _make_manager(installed=None, overrides=None): - """Build an EgressManager backed by a mock config_manager.""" +CONNECTIONS = { + 'conn_wg': { + 'id': 'conn_wg', 'type': 'wireguard_ext', 'name': 'Work VPN', + 'enabled': True, 'mark': 0x1000, 'table': 1000, + 'iface': 'wgext_aaaa', 'redirect_port': None, + }, + 'conn_ovpn': { + 'id': 'conn_ovpn', 'type': 'openvpn', 'name': 'OVPN', + 'enabled': True, 'mark': 0x1010, 'table': 1001, + 'iface': 'ovpn_bbbb', 'redirect_port': None, + }, + 'conn_tor': { + 'id': 'conn_tor', 'type': 'tor', 'name': 'Tor', + 'enabled': True, 'mark': 0x1020, 'table': 1002, + 'iface': None, 'redirect_port': 9100, + }, +} + + +def _make_manager(installed=None, overrides=None, connections=None): + """Build an EgressManager backed by mock config + connectivity managers.""" cm = MagicMock() cm.get_installed_services.return_value = installed or {} - # Wire up configs dict so _get_egress_overrides / _set_egress_overrides work cm.configs = {'egress_overrides': overrides or {}} cm._save_all_configs = MagicMock() - return EgressManager(config_manager=cm), cm + + conns = connections if connections is not None else CONNECTIONS + conn_list = list(conns.values()) + cm.list_connections.return_value = conn_list + cm.get_connection.side_effect = lambda cid: conns.get(cid) + + connectivity = MagicMock() + connectivity.list_connections.return_value = conn_list + connectivity.get_connection.side_effect = lambda cid: conns.get(cid) + + mgr = EgressManager(config_manager=cm, connectivity_manager=connectivity) + return mgr, cm def _subprocess_ok(stdout=''): - """Return a MagicMock simulating a successful subprocess.run result.""" return MagicMock(returncode=0, stdout=stdout, stderr='') def _subprocess_fail(stderr='error', stdout=''): - """Return a MagicMock simulating a failed subprocess.run result.""" return MagicMock(returncode=1, stdout=stdout, stderr=stderr) -def _make_manifest(has_egress=True, egress_default='wireguard_ext', +def _make_manifest(has_egress=True, egress_default='conn_wg', allowed=None, container_name='cell-myapp'): - """Return a minimal manifest dict with optional egress configuration.""" + """Return a minimal manifest dict with optional egress configuration. + + `allowed` is a list of connection *types* (manifests are type-scoped). + """ m = { 'id': 'myapp', 'name': 'My App', @@ -53,7 +89,8 @@ def _make_manifest(has_egress=True, egress_default='wireguard_ext', m['has_egress'] = True m['egress'] = { 'default': egress_default, - 'allowed': allowed if allowed is not None else list(EXIT_TYPES), + 'allowed': allowed if allowed is not None + else ['wireguard_ext', 'openvpn', 'tor', 'sshuttle', 'proxy'], } else: m['has_egress'] = False @@ -61,32 +98,26 @@ def _make_manifest(has_egress=True, egress_default='wireguard_ext', def _installed_with_manifest(manifest, service_id='myapp'): - """Return an installed-services dict containing one service record.""" return {service_id: {'id': service_id, 'manifest': manifest}} # --------------------------------------------------------------------------- -# 1. test_apply_service_default_exit_no_iptables_calls +# 1. default exit → no iptables rule-adding # --------------------------------------------------------------------------- class TestApplyServiceDefaultExit(unittest.TestCase): def test_apply_service_default_exit_no_iptables_calls(self): - """When egress.default is 'default', apply_service must not touch iptables.""" manifest = _make_manifest(egress_default='default') mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) with patch('subprocess.run') as mock_run: - # docker inspect must return an IP so we don't fail earlier mock_run.return_value = _subprocess_ok(stdout='172.20.0.50\n') result = mgr.apply_service('myapp') self.assertTrue(result['ok']) self.assertEqual(result.get('exit_via'), 'default') - # No iptables rule-insertion or mark call should have been made. - # iptables-save from clear_service is allowed; we only check that - # no iptables -A / -I (rule-adding) calls were made. rule_add_calls = [ c for c in mock_run.call_args_list if c.args and c.args[0][:1] == ['iptables'] @@ -96,30 +127,25 @@ class TestApplyServiceDefaultExit(unittest.TestCase): # --------------------------------------------------------------------------- -# 2. test_apply_service_wireguard_ext_adds_mark_rule +# 2. wireguard_ext connection → mark rule with the connection's own mark # --------------------------------------------------------------------------- class TestApplyServiceWireguardExt(unittest.TestCase): def test_apply_service_wireguard_ext_adds_mark_rule(self): - """wireguard_ext exit must add a mangle MARK rule with 0x110 and the correct comment.""" - manifest = _make_manifest(egress_default='wireguard_ext') + manifest = _make_manifest(egress_default='conn_wg') mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) calls_made = [] def fake_run(cmd, **kwargs): calls_made.append(cmd) - # docker inspect → return IP if 'docker' in cmd and 'inspect' in cmd: return _subprocess_ok(stdout='172.20.0.50\n') - # iptables-save → empty ruleset if 'iptables-save' in cmd: return _subprocess_ok(stdout='') - # iptables-restore → success if 'iptables-restore' in cmd: return _subprocess_ok() - # ip rule del → fail (none to delete) if cmd[:3] == ['ip', 'rule', 'del']: return _subprocess_fail() return _subprocess_ok() @@ -128,29 +154,35 @@ class TestApplyServiceWireguardExt(unittest.TestCase): result = mgr.apply_service('myapp') self.assertTrue(result['ok'], result) - self.assertEqual(result['exit_via'], 'wireguard_ext') + self.assertEqual(result['exit_via'], 'conn_wg') - # Find the mangle MARK -A call mark_calls = [ c for c in calls_made if 'iptables' in str(c) and 'MARK' in c and '--set-mark' in c ] self.assertGreater(len(mark_calls), 0, 'No MARK rule was added') mark_cmd = ' '.join(mark_calls[0]) - self.assertIn('0x110', mark_cmd) + self.assertIn('0x1000', mark_cmd) # the connection's mark self.assertIn('pic-egr-myapp', mark_cmd) self.assertIn('mangle', mark_cmd) + # The ip rule must point fwmark→the connection's table. + ip_rule_add = [ + c for c in calls_made + if c[:3] == ['ip', 'rule', 'add'] + ] + self.assertGreater(len(ip_rule_add), 0) + self.assertIn('1000', ' '.join(ip_rule_add[0])) # table + # --------------------------------------------------------------------------- -# 3. test_apply_service_openvpn_adds_mark_rule +# 3. openvpn connection → its own mark # --------------------------------------------------------------------------- class TestApplyServiceOpenVPN(unittest.TestCase): def test_apply_service_openvpn_adds_mark_rule(self): - """openvpn exit must add a mangle MARK rule with 0x120.""" - manifest = _make_manifest(egress_default='openvpn') + manifest = _make_manifest(egress_default='conn_ovpn') mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) calls_made = [] @@ -171,25 +203,24 @@ class TestApplyServiceOpenVPN(unittest.TestCase): result = mgr.apply_service('myapp') self.assertTrue(result['ok'], result) - self.assertEqual(result['exit_via'], 'openvpn') + self.assertEqual(result['exit_via'], 'conn_ovpn') mark_calls = [ c for c in calls_made if 'iptables' in str(c) and 'MARK' in c and '--set-mark' in c ] self.assertGreater(len(mark_calls), 0) - self.assertIn('0x120', ' '.join(mark_calls[0])) + self.assertIn('0x1010', ' '.join(mark_calls[0])) # --------------------------------------------------------------------------- -# 4. test_apply_service_tor_adds_mark_and_redirect +# 4. tor (redirect-style) connection → mark + REDIRECT to its redirect_port # --------------------------------------------------------------------------- class TestApplyServiceTor(unittest.TestCase): def test_apply_service_tor_adds_mark_and_redirect(self): - """tor exit must add a mangle MARK 0x130 AND a nat REDIRECT to port 9040.""" - manifest = _make_manifest(egress_default='tor') + manifest = _make_manifest(egress_default='conn_tor') mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) calls_made = [] @@ -210,14 +241,14 @@ class TestApplyServiceTor(unittest.TestCase): result = mgr.apply_service('myapp') self.assertTrue(result['ok'], result) - self.assertEqual(result['exit_via'], 'tor') + self.assertEqual(result['exit_via'], 'conn_tor') mark_calls = [ c for c in calls_made if 'iptables' in str(c) and 'MARK' in c and '--set-mark' in c ] self.assertGreater(len(mark_calls), 0, 'No MARK rule found') - self.assertIn('0x130', ' '.join(mark_calls[0])) + self.assertIn('0x1020', ' '.join(mark_calls[0])) redirect_calls = [ c for c in calls_made @@ -225,24 +256,23 @@ class TestApplyServiceTor(unittest.TestCase): ] self.assertGreater(len(redirect_calls), 0, 'No REDIRECT rule found') redirect_cmd = ' '.join(redirect_calls[0]) - self.assertIn('9040', redirect_cmd) + self.assertIn('9100', redirect_cmd) # the connection's redirect_port self.assertIn('nat', redirect_cmd) # --------------------------------------------------------------------------- -# 5. test_apply_service_no_container_ip_returns_error +# 5. no container IP → error # --------------------------------------------------------------------------- class TestApplyServiceNoContainerIP(unittest.TestCase): def test_apply_service_no_container_ip_returns_error(self): - """When docker inspect returns an empty IP, apply_service must return ok=False.""" - manifest = _make_manifest(egress_default='wireguard_ext') + manifest = _make_manifest(egress_default='conn_wg') mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) def fake_run(cmd, **kwargs): if 'docker' in cmd and 'inspect' in cmd: - return _subprocess_ok(stdout='\n') # empty IP + return _subprocess_ok(stdout='\n') if 'iptables-save' in cmd: return _subprocess_ok(stdout='') if 'iptables-restore' in cmd: @@ -257,14 +287,13 @@ class TestApplyServiceNoContainerIP(unittest.TestCase): # --------------------------------------------------------------------------- -# 6. test_apply_service_container_ip_retries +# 6. container IP retries # --------------------------------------------------------------------------- class TestApplyServiceRetries(unittest.TestCase): def test_apply_service_container_ip_retries(self): - """First docker inspect attempt fails; second succeeds — result must be ok=True.""" - manifest = _make_manifest(egress_default='wireguard_ext') + manifest = _make_manifest(egress_default='conn_wg') mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) inspect_count = [0] @@ -273,8 +302,8 @@ class TestApplyServiceRetries(unittest.TestCase): if 'docker' in cmd and 'inspect' in cmd: inspect_count[0] += 1 if inspect_count[0] == 1: - return _subprocess_ok(stdout='\n') # first attempt: empty - return _subprocess_ok(stdout='172.20.0.50\n') # second: success + return _subprocess_ok(stdout='\n') + return _subprocess_ok(stdout='172.20.0.50\n') if 'iptables-save' in cmd: return _subprocess_ok(stdout='') if 'iptables-restore' in cmd: @@ -284,7 +313,7 @@ class TestApplyServiceRetries(unittest.TestCase): return _subprocess_ok() with patch('subprocess.run', side_effect=fake_run): - with patch('time.sleep'): # skip actual delays + with patch('time.sleep'): result = mgr.apply_service('myapp') self.assertTrue(result['ok'], result) @@ -292,13 +321,12 @@ class TestApplyServiceRetries(unittest.TestCase): # --------------------------------------------------------------------------- -# 7. test_has_egress_false_skips_rules +# 7. has_egress False → skipped # --------------------------------------------------------------------------- class TestHasEgressFalse(unittest.TestCase): def test_has_egress_false_skips_rules(self): - """A manifest with has_egress=False must skip rules and return skipped=True.""" manifest = _make_manifest(has_egress=False) mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) @@ -309,8 +337,6 @@ class TestHasEgressFalse(unittest.TestCase): self.assertTrue(result['ok']) self.assertTrue(result.get('skipped')) - # No iptables rule-insertion call should have been made. - # iptables-save from clear_service is permitted; only check no -A/-I. rule_add_calls = [ c for c in mock_run.call_args_list if c.args and c.args[0][:1] == ['iptables'] @@ -320,22 +346,18 @@ class TestHasEgressFalse(unittest.TestCase): # --------------------------------------------------------------------------- -# 8. test_has_egress_missing_egress_block_skips +# 8. has_egress True but no egress block → skipped # --------------------------------------------------------------------------- class TestHasEgressMissingBlock(unittest.TestCase): def test_has_egress_missing_egress_block_skips(self): - """has_egress=True but no 'egress' dict → must skip (skipped=True).""" manifest = { 'id': 'myapp', 'container_name': 'cell-myapp', 'has_egress': True, - # 'egress' key intentionally absent } - mgr, _ = _make_manager( - installed=_installed_with_manifest(manifest) - ) + mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) with patch('subprocess.run') as mock_run: mock_run.return_value = _subprocess_ok(stdout='') @@ -346,23 +368,21 @@ class TestHasEgressMissingBlock(unittest.TestCase): # --------------------------------------------------------------------------- -# 9. test_clear_service_removes_tagged_rules +# 9. clear_service removes only the tagged rules # --------------------------------------------------------------------------- class TestClearService(unittest.TestCase): def test_clear_service_removes_tagged_rules(self): - """iptables-restore is called with the tagged lines removed.""" mgr, _ = _make_manager() mangle_rules = ( - '-A PIC_EGRESS -s 172.20.0.50 -j MARK --set-mark 0x110 ' + '-A PIC_EGRESS -s 172.20.0.50 -j MARK --set-mark 0x1000 ' '-m comment --comment "pic-egr-myapp"\n' - '-A PIC_EGRESS -s 172.20.0.99 -j MARK --set-mark 0x110 ' + '-A PIC_EGRESS -s 172.20.0.99 -j MARK --set-mark 0x1000 ' '-m comment --comment "pic-egr-otherapp"\n' ) nat_rules = '' - restore_inputs = {} def fake_run(cmd, input=None, **kwargs): @@ -382,28 +402,25 @@ class TestClearService(unittest.TestCase): result = mgr.clear_service('myapp') self.assertTrue(result['ok']) - # The restored mangle rules must not contain myapp's tag restored = restore_inputs.get('mangle', '') self.assertNotIn('pic-egr-myapp', restored) - # But the other service's rules must be preserved self.assertIn('pic-egr-otherapp', restored) # --------------------------------------------------------------------------- -# 10. test_set_service_exit_rejects_not_in_allowed +# 10. set_service_exit rejects a connection whose type is not in allowed # --------------------------------------------------------------------------- class TestSetServiceExitRejectNotAllowed(unittest.TestCase): - def test_set_service_exit_rejects_not_in_allowed(self): - """Exit type not in manifest's allowed list must return ok=False.""" + def test_set_service_exit_rejects_type_not_in_allowed(self): manifest = _make_manifest( egress_default='default', - allowed=['default', 'tor'], # wireguard_ext not in allowed + allowed=['tor'], # wireguard_ext not allowed ) mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) - result = mgr.set_service_exit('myapp', 'wireguard_ext') + result = mgr.set_service_exit('myapp', 'conn_wg') self.assertFalse(result['ok']) self.assertIn('error', result) @@ -411,47 +428,71 @@ class TestSetServiceExitRejectNotAllowed(unittest.TestCase): # --------------------------------------------------------------------------- -# 11. test_set_service_exit_persists_and_applies +# 11. set_service_exit persists the connection id and applies # --------------------------------------------------------------------------- class TestSetServiceExitPersistsAndApplies(unittest.TestCase): def test_set_service_exit_persists_and_applies(self): - """Valid override must be persisted to config_manager and apply_service called.""" - manifest = _make_manifest(egress_default='default', allowed=list(EXIT_TYPES)) + manifest = _make_manifest(egress_default='default') mgr, cm = _make_manager(installed=_installed_with_manifest(manifest)) apply_calls = [] - original_apply = mgr.apply_service + mgr.apply_service = lambda sid: apply_calls.append(sid) or { + 'ok': True, 'exit_via': 'conn_tor'} - def fake_apply(sid): - apply_calls.append(sid) - return {'ok': True, 'exit_via': 'tor'} + result = mgr.set_service_exit('myapp', 'conn_tor') - mgr.apply_service = fake_apply + self.assertTrue(result['ok'], result) + self.assertIn('myapp', apply_calls) + cm._save_all_configs.assert_called() + self.assertEqual(cm.configs['egress_overrides'].get('myapp'), 'conn_tor') + + def test_set_service_exit_default_clears_override(self): + manifest = _make_manifest(egress_default='conn_wg') + mgr, cm = _make_manager( + installed=_installed_with_manifest(manifest), + overrides={'myapp': 'conn_tor'}, + ) + mgr.apply_service = lambda sid: {'ok': True, 'exit_via': 'default'} + + result = mgr.set_service_exit('myapp', 'default') + + self.assertTrue(result['ok']) + self.assertEqual(cm.configs['egress_overrides'].get('myapp'), 'default') + + def test_set_service_exit_unknown_connection_rejected(self): + manifest = _make_manifest(egress_default='default') + mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) + + result = mgr.set_service_exit('myapp', 'conn_ghost') + + self.assertFalse(result['ok']) + self.assertIn('unknown connection', result['error']) + + def test_set_service_exit_legacy_type_resolves_to_single_instance(self): + """Back-compat shim: a legacy type resolves to the one instance of it.""" + manifest = _make_manifest(egress_default='default', allowed=['tor']) + mgr, cm = _make_manager(installed=_installed_with_manifest(manifest)) + mgr.apply_service = lambda sid: {'ok': True} result = mgr.set_service_exit('myapp', 'tor') self.assertTrue(result['ok'], result) - # apply_service was called - self.assertIn('myapp', apply_calls) - # override was persisted - cm._save_all_configs.assert_called() - self.assertEqual(cm.configs['egress_overrides'].get('myapp'), 'tor') + self.assertEqual(cm.configs['egress_overrides'].get('myapp'), 'conn_tor') # --------------------------------------------------------------------------- -# 12. test_apply_all_iterates_installed_services +# 12. apply_all iterates installed services # --------------------------------------------------------------------------- class TestApplyAll(unittest.TestCase): def test_apply_all_iterates_installed_services(self): - """apply_all must call apply_service for every service with a manifest.""" manifests = { - 'svc1': _make_manifest(egress_default='wireguard_ext'), - 'svc2': _make_manifest(egress_default='openvpn'), - 'svc3': _make_manifest(egress_default='tor'), + 'svc1': _make_manifest(egress_default='conn_wg'), + 'svc2': _make_manifest(egress_default='conn_ovpn'), + 'svc3': _make_manifest(egress_default='conn_tor'), } installed = { sid: {'id': sid, 'manifest': m} @@ -469,37 +510,57 @@ class TestApplyAll(unittest.TestCase): # --------------------------------------------------------------------------- -# 13. test_marks_do_not_collide_with_connectivity_manager +# 13. service + peer on the same connection share the same mark # --------------------------------------------------------------------------- -class TestMarksNoCollision(unittest.TestCase): +class TestSharedMarkWithConnection(unittest.TestCase): - def test_marks_do_not_collide_with_connectivity_manager(self): - """EgressManager marks must be disjoint from ConnectivityManager marks.""" - connectivity_marks = {0x10, 0x20, 0x30} - egress_mark_values = set(MARKS.values()) - collision = connectivity_marks & egress_mark_values - self.assertEqual( - collision, set(), - f'Mark collision with ConnectivityManager: {collision}', - ) + def test_service_uses_connection_mark_not_a_private_table(self): + """A service routed via a connection uses that connection's mark/table, + i.e. the SAME resources a peer on that connection would use.""" + manifest = _make_manifest(egress_default='conn_wg') + mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) + + captured = {} + + def fake_run(cmd, **kwargs): + if 'docker' in cmd and 'inspect' in cmd: + return _subprocess_ok(stdout='172.20.0.50\n') + if 'iptables-save' in cmd: + return _subprocess_ok(stdout='') + if 'iptables-restore' in cmd: + return _subprocess_ok() + if cmd[:3] == ['ip', 'rule', 'del']: + return _subprocess_fail() + if cmd[:3] == ['ip', 'rule', 'add']: + captured['ip_rule'] = ' '.join(cmd) + if 'MARK' in cmd: + captured['mark'] = ' '.join(cmd) + return _subprocess_ok() + + with patch('subprocess.run', side_effect=fake_run): + mgr.apply_service('myapp') + + # conn_wg has mark 0x1000 / table 1000 — both must appear, proving the + # service inherits the connection's shared routing resources. + self.assertIn('0x1000', captured.get('mark', '')) + self.assertIn('1000', captured.get('ip_rule', '')) # --------------------------------------------------------------------------- -# 14. test_apply_service_unknown_exit_in_allowed_rejected +# 14. apply_service with an unknown connection id → error # --------------------------------------------------------------------------- class TestApplyServiceUnknownExit(unittest.TestCase): - def test_apply_service_unknown_exit_in_allowed_rejected(self): - """An egress.default value that is not a known EXIT_TYPE must return ok=False.""" + def test_apply_service_unknown_connection_rejected(self): manifest = { 'id': 'myapp', 'container_name': 'cell-myapp', 'has_egress': True, 'egress': { - 'default': 'internet_fast_lane', # unknown exit - 'allowed': ['internet_fast_lane'], + 'default': 'conn_ghost', # no such connection + 'allowed': ['wireguard_ext'], }, } mgr, _ = _make_manager(installed=_installed_with_manifest(manifest)) @@ -521,7 +582,7 @@ class TestApplyServiceUnknownExit(unittest.TestCase): # --------------------------------------------------------------------------- -# Additional coverage: _has_egress edge cases +# _has_egress edge cases # --------------------------------------------------------------------------- class TestHasEgressLogic(unittest.TestCase): @@ -530,16 +591,15 @@ class TestHasEgressLogic(unittest.TestCase): self.mgr, _ = _make_manager() def test_has_egress_both_required(self): - """Both has_egress=True and non-empty egress dict required.""" - m = {'has_egress': True, 'egress': {'default': 'tor', 'allowed': ['tor']}} + m = {'has_egress': True, 'egress': {'default': 'conn_tor', 'allowed': ['tor']}} self.assertTrue(self.mgr._has_egress(m)) def test_has_egress_false_field(self): - m = {'has_egress': False, 'egress': {'default': 'tor', 'allowed': ['tor']}} + m = {'has_egress': False, 'egress': {'default': 'conn_tor', 'allowed': ['tor']}} self.assertFalse(self.mgr._has_egress(m)) def test_has_egress_missing_has_egress_key(self): - m = {'egress': {'default': 'tor', 'allowed': ['tor']}} + m = {'egress': {'default': 'conn_tor', 'allowed': ['tor']}} self.assertFalse(self.mgr._has_egress(m)) def test_has_egress_empty_egress_dict(self): @@ -548,29 +608,41 @@ class TestHasEgressLogic(unittest.TestCase): # --------------------------------------------------------------------------- -# Additional coverage: _resolve_exit +# _resolve_exit (now returns connection ids) # --------------------------------------------------------------------------- class TestResolveExit(unittest.TestCase): def test_override_takes_precedence(self): - mgr, _ = _make_manager(overrides={'myapp': 'openvpn'}) - manifest = _make_manifest(egress_default='wireguard_ext') - self.assertEqual(mgr._resolve_exit('myapp', manifest), 'openvpn') + mgr, _ = _make_manager(overrides={'myapp': 'conn_ovpn'}) + manifest = _make_manifest(egress_default='conn_wg') + self.assertEqual(mgr._resolve_exit('myapp', manifest), 'conn_ovpn') def test_manifest_default_used_when_no_override(self): mgr, _ = _make_manager(overrides={}) - manifest = _make_manifest(egress_default='tor') - self.assertEqual(mgr._resolve_exit('myapp', manifest), 'tor') + manifest = _make_manifest(egress_default='conn_tor') + self.assertEqual(mgr._resolve_exit('myapp', manifest), 'conn_tor') def test_fallback_to_default_when_no_egress_block(self): mgr, _ = _make_manager(overrides={}) manifest = {'id': 'myapp'} self.assertEqual(mgr._resolve_exit('myapp', manifest), 'default') + def test_legacy_type_override_migrates_to_connection_id(self): + """An old override holding a type string resolves to the migrated id.""" + mgr, _ = _make_manager(overrides={'myapp': 'wireguard_ext'}) + manifest = _make_manifest(egress_default='default') + self.assertEqual(mgr._resolve_exit('myapp', manifest), 'conn_wg') + + def test_legacy_type_default_with_no_instance_falls_back(self): + """A legacy type with no matching instance falls back to 'default'.""" + mgr, _ = _make_manager(connections={}) + manifest = _make_manifest(egress_default='tor') + self.assertEqual(mgr._resolve_exit('myapp', manifest), 'default') + # --------------------------------------------------------------------------- -# Additional: apply_service with missing manifest +# apply_service with missing manifest # --------------------------------------------------------------------------- class TestApplyServiceMissingManifest(unittest.TestCase): diff --git a/tests/test_peer_registry.py b/tests/test_peer_registry.py index 09d87ec..a8f2c66 100644 --- a/tests/test_peer_registry.py +++ b/tests/test_peer_registry.py @@ -5,6 +5,7 @@ import os import json import sys from pathlib import Path +from unittest.mock import MagicMock # Add api directory to path api_dir = Path(__file__).parent.parent / 'api' @@ -107,35 +108,90 @@ class TestPeerRegistry(unittest.TestCase): with self.assertRaises(ValueError): self.registry.set_route_via('nobody', 'exit-cell') - def test_set_peer_exit_via_valid(self): + def _connectivity_cm(self, connections): + """A mock config_manager exposing v2 connection records.""" + cm = MagicMock() + cm.list_connections.return_value = connections + return cm + + def test_set_peer_exit_via_valid_connection_id(self): + conns = [{'id': 'conn_wg', 'type': 'wireguard_ext'}] + self.registry.config_manager = self._connectivity_cm(conns) self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) - result = self.registry.set_peer_exit_via('alice', 'wireguard_ext') + result = self.registry.set_peer_exit_via('alice', 'conn_wg') self.assertTrue(result) - peer = self.registry.get_peer('alice') - self.assertEqual(peer['exit_via'], 'wireguard_ext') + self.assertEqual(self.registry.get_peer('alice')['exit_via'], 'conn_wg') - def test_set_peer_exit_via_all_valid_types(self): + def test_set_peer_exit_via_default_always_valid(self): + self.registry.config_manager = self._connectivity_cm([]) self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) - for exit_type in ('default', 'wireguard_ext', 'openvpn', 'tor'): - result = self.registry.set_peer_exit_via('alice', exit_type) - self.assertTrue(result) - peer = self.registry.get_peer('alice') - self.assertEqual(peer['exit_via'], exit_type) + result = self.registry.set_peer_exit_via('alice', 'default') + self.assertTrue(result) + self.assertEqual(self.registry.get_peer('alice')['exit_via'], 'default') - def test_set_peer_exit_via_invalid_type_returns_false(self): + def test_set_peer_exit_via_legacy_type_resolves_to_instance(self): + """Back-compat shim: a legacy type resolves to the one instance of it.""" + conns = [{'id': 'conn_tor', 'type': 'tor'}] + self.registry.config_manager = self._connectivity_cm(conns) self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) - result = self.registry.set_peer_exit_via('alice', 'invalid_exit') + result = self.registry.set_peer_exit_via('alice', 'tor') + self.assertTrue(result) + self.assertEqual(self.registry.get_peer('alice')['exit_via'], 'conn_tor') + + def test_set_peer_exit_via_unknown_id_returns_false(self): + self.registry.config_manager = self._connectivity_cm([]) + self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) + result = self.registry.set_peer_exit_via('alice', 'conn_ghost') self.assertFalse(result) def test_set_peer_exit_via_nonexistent_peer_returns_false(self): + self.registry.config_manager = self._connectivity_cm([]) result = self.registry.set_peer_exit_via('nobody', 'default') self.assertFalse(result) def test_set_peer_exit_via_persists(self): + conns = [{'id': 'conn_tor', 'type': 'tor'}] + cm = self._connectivity_cm(conns) + self.registry.config_manager = cm self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) - self.registry.set_peer_exit_via('alice', 'tor') - reloaded = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir) - self.assertEqual(reloaded.get_peer('alice')['exit_via'], 'tor') + self.registry.set_peer_exit_via('alice', 'conn_tor') + reloaded = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir, + config_manager=cm) + self.assertEqual(reloaded.get_peer('alice')['exit_via'], 'conn_tor') + + def test_exit_via_migration_legacy_type_to_id(self): + """On load, a legacy per-type exit_via becomes the migrated instance id.""" + peers_file = os.path.join(self.test_dir, 'peers.json') + with open(peers_file, 'w') as f: + json.dump([{'peer': 'alice', 'ip': '10.0.0.5', + 'exit_via': 'wireguard_ext'}], f) + cm = self._connectivity_cm([{'id': 'conn_wg', 'type': 'wireguard_ext'}]) + reg = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir, + config_manager=cm) + self.assertEqual(reg.get_peer('alice')['exit_via'], 'conn_wg') + + def test_exit_via_migration_unknown_type_to_default(self): + """A legacy type with no migrated instance falls back to 'default'.""" + peers_file = os.path.join(self.test_dir, 'peers.json') + with open(peers_file, 'w') as f: + json.dump([{'peer': 'alice', 'ip': '10.0.0.5', + 'exit_via': 'openvpn'}], f) + cm = self._connectivity_cm([]) # no instances + reg = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir, + config_manager=cm) + self.assertEqual(reg.get_peer('alice')['exit_via'], 'default') + + def test_exit_via_migration_id_is_idempotent(self): + """An already-migrated id is left untouched and not re-migrated.""" + peers_file = os.path.join(self.test_dir, 'peers.json') + with open(peers_file, 'w') as f: + json.dump([{'peer': 'alice', 'ip': '10.0.0.5', + 'exit_via': 'conn_wg'}], f) + cm = self._connectivity_cm([{'id': 'conn_wg', 'type': 'wireguard_ext'}]) + reg = PeerRegistry(data_dir=self.test_dir, config_dir=self.test_dir, + config_manager=cm) + self.assertEqual(reg.get_peer('alice')['exit_via'], 'conn_wg') + self.assertFalse(reg._migrate_exit_via_to_connection_id()) def test_update_peer_updates_arbitrary_fields(self): self.registry.add_peer({'peer': 'alice', 'ip': '10.0.0.5'}) diff --git a/webui/src/services/api.js b/webui/src/services/api.js index 5d5df47..dd2097e 100644 --- a/webui/src/services/api.js +++ b/webui/src/services/api.js @@ -360,8 +360,8 @@ export const setupAPI = { // Per-service Egress API export const egressAPI = { getStatus: () => api.get('/api/egress/status'), - setServiceExit: (serviceId, exitType) => - api.put(`/api/egress/services/${serviceId}/exit`, { exit_type: exitType }), + setServiceExit: (serviceId, connectionId) => + api.put(`/api/egress/services/${serviceId}/exit`, { connection_id: connectionId }), }; // Connectivity / Exit Routing API @@ -374,7 +374,7 @@ export const connectivityAPI = { configureProxy: (cfg) => api.post('/api/connectivity/exits/proxy', cfg), applyRoutes: () => api.post('/api/connectivity/exits/apply'), getPeerExits: () => api.get('/api/connectivity/peers'), - setPeerExit: (peer_name, exit_via) => api.put(`/api/connectivity/peers/${peer_name}/exit`, { exit_via }), + setPeerExit: (peer_name, connection_id) => api.put(`/api/connectivity/peers/${peer_name}/exit`, { connection_id }), }; // Container Management API