#!/usr/bin/env python3
"""
Unit tests for VaultManager.

Covers the certificate authority, certificate issuance / listing / revocation,
trusted-key management and trust-chain verification, status reporting, and
secret storage.
"""

import json
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

# Make the sibling ``api`` directory importable so ``vault_manager`` resolves
# without installing the package. This must happen before the import below.
api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))

from vault_manager import VaultManager  # noqa: E402  (depends on sys.path tweak above)


class _VaultTestCase(unittest.TestCase):
    """Shared fixture: isolated temp config/data dirs and a patched ``subprocess.run``.

    Cleanups are registered with ``addCleanup`` instead of ``tearDown`` so they
    still run when a subclass's ``setUp`` fails after this one returns;
    otherwise the global ``subprocess.run`` patch and the temp dir would leak
    into later tests.
    """

    def setUp(self):
        """Create temp directories and patch ``subprocess.run``."""
        self.test_dir = tempfile.mkdtemp()
        # Registered first, so (cleanups are LIFO) it runs after the patch is stopped.
        self.addCleanup(shutil.rmtree, self.test_dir)

        self.config_dir = os.path.join(self.test_dir, "config")
        self.data_dir = os.path.join(self.test_dir, "data")
        os.makedirs(self.config_dir, exist_ok=True)
        os.makedirs(self.data_dir, exist_ok=True)

        # Prevent VaultManager from shelling out during tests (the stdout value
        # below suggests it may invoke the `age` CLI -- confirm in vault_manager).
        self.age_patcher = patch('subprocess.run')
        self.mock_age = self.age_patcher.start()
        self.addCleanup(self.age_patcher.stop)

        # A successful process result; returncode must be a real 0 because a
        # bare MagicMock attribute compares unequal to 0.
        mock_result = MagicMock()
        mock_result.stdout = "age1testkey123456789\n"
        mock_result.returncode = 0
        self.mock_age.return_value = mock_result


class TestVaultManager(_VaultTestCase):
    """Test cases for VaultManager."""

    def setUp(self):
        """Set up test environment with a fresh VaultManager instance."""
        super().setUp()
        self.vault = VaultManager(self.config_dir, self.data_dir)

    def test_init_creates_directories(self):
        """Test that initialization creates required directories."""
        vault_dir = Path(self.data_dir) / "vault"
        expected_dirs = (
            vault_dir,
            vault_dir / "ca",
            vault_dir / "certs",
            vault_dir / "keys",
            vault_dir / "trust",
        )
        for directory in expected_dirs:
            with self.subTest(directory=str(directory)):
                self.assertTrue(directory.exists())

    def test_ca_creation(self):
        """Test CA creation."""
        from cryptography import x509

        self.assertTrue(self.vault.ca_key_file.exists())
        self.assertTrue(self.vault.ca_cert_file.exists())

        # The CA certificate must be marked as a CA in its BasicConstraints.
        with open(self.vault.ca_cert_file, "rb") as f:
            cert = x509.load_pem_x509_certificate(f.read())
        basic_constraints = cert.extensions.get_extension_for_oid(
            x509.oid.ExtensionOID.BASIC_CONSTRAINTS
        )
        self.assertTrue(basic_constraints.value.ca)

    def test_generate_certificate(self):
        """Test certificate generation."""
        cert_info = self.vault.generate_certificate(
            common_name="test.example.com",
            domains=["test.example.com", "www.test.example.com"],
            key_size=2048,
            days=365,
        )

        self.assertEqual(cert_info["common_name"], "test.example.com")
        self.assertEqual(
            cert_info["domains"], ["test.example.com", "www.test.example.com"]
        )
        self.assertTrue(cert_info["cert_file"])
        self.assertTrue(cert_info["key_file"])
        self.assertTrue(cert_info["encrypted"])

        # The reported paths must point at files that were actually written.
        self.assertTrue(Path(cert_info["cert_file"]).exists())
        self.assertTrue(Path(cert_info["key_file"]).exists())

    def test_generate_certificate_without_domains(self):
        """Test certificate generation without domains."""
        cert_info = self.vault.generate_certificate(common_name="simple.example.com")

        self.assertEqual(cert_info["common_name"], "simple.example.com")
        self.assertEqual(cert_info["domains"], [])

    def test_list_certificates(self):
        """Test listing certificates."""
        self.vault.generate_certificate("test.example.com")

        certificates = self.vault.list_certificates()

        self.assertEqual(len(certificates), 1)
        self.assertEqual(certificates[0]["common_name"], "test.example.com")
        self.assertFalse(certificates[0]["expired"])

    def test_revoke_certificate(self):
        """Test certificate revocation."""
        self.vault.generate_certificate("test.example.com")

        cert_file = self.vault.certs_dir / "test.example.com.crt"
        key_file = self.vault.certs_dir / "test.example.com.key"
        self.assertTrue(cert_file.exists())
        self.assertTrue(key_file.exists())

        result = self.vault.revoke_certificate("test.example.com")
        self.assertTrue(result)

        # Revocation removes both the certificate and its key from disk.
        self.assertFalse(cert_file.exists())
        self.assertFalse(key_file.exists())

    def test_revoke_nonexistent_certificate(self):
        """Test revoking non-existent certificate."""
        result = self.vault.revoke_certificate("nonexistent.example.com")
        self.assertTrue(result)  # Should not raise exception

    def test_add_trusted_key(self):
        """Test adding trusted key."""
        result = self.vault.add_trusted_key(
            name="test-peer",
            public_key="age1testkey123456789",
            trust_level="direct",
        )
        self.assertTrue(result)

        trusted_keys = self.vault.get_trusted_keys()
        self.assertIn("test-peer", trusted_keys)
        self.assertEqual(
            trusted_keys["test-peer"]["public_key"], "age1testkey123456789"
        )
        self.assertEqual(trusted_keys["test-peer"]["trust_level"], "direct")

    def test_remove_trusted_key(self):
        """Test removing trusted key."""
        self.vault.add_trusted_key("test-peer", "age1testkey123456789")

        result = self.vault.remove_trusted_key("test-peer")
        self.assertTrue(result)

        self.assertNotIn("test-peer", self.vault.get_trusted_keys())

    def test_remove_nonexistent_trusted_key(self):
        """Test removing non-existent trusted key."""
        result = self.vault.remove_trusted_key("nonexistent-peer")
        self.assertFalse(result)

    def test_verify_trust_chain(self):
        """Test trust chain verification."""
        self.vault.add_trusted_key("test-peer", "age1testkey123456789")

        result = self.vault.verify_trust_chain(
            peer_name="test-peer",
            signature="test-signature",
            data="test-data",
        )
        self.assertTrue(result)

        # A successful verification is recorded under the peer's name.
        trust_chains = self.vault.get_trust_chains()
        self.assertIn("test-peer", trust_chains)
        self.assertEqual(trust_chains["test-peer"]["signature"], "test-signature")
        self.assertEqual(trust_chains["test-peer"]["data"], "test-data")

    def test_verify_trust_chain_unknown_peer(self):
        """Test trust chain verification with unknown peer."""
        result = self.vault.verify_trust_chain(
            peer_name="unknown-peer",
            signature="test-signature",
            data="test-data",
        )
        self.assertFalse(result)

    def test_get_ca_certificate(self):
        """Test getting CA certificate."""
        cert = self.vault.get_ca_certificate()

        self.assertIsInstance(cert, str)
        self.assertTrue(cert.startswith("-----BEGIN CERTIFICATE-----"))
        self.assertTrue(cert.endswith("-----END CERTIFICATE-----\n"))

    def test_get_status(self):
        """Test getting vault status."""
        status = self.vault.get_status()

        self.assertIsInstance(status, dict)
        expected_keys = (
            "ca_configured",
            "age_configured",
            "certificates_count",
            "trusted_keys_count",
            "trust_chains_count",
            "certificates",
            "trusted_keys",
            "ca_certificate",
            "age_public_key",
        )
        for key in expected_keys:
            with self.subTest(key=key):
                self.assertIn(key, status)

        self.assertTrue(status["ca_configured"])
        self.assertIsInstance(status["certificates"], list)
        self.assertIsInstance(status["trusted_keys"], list)

    def test_certificate_with_sans(self):
        """Test certificate generation with Subject Alternative Names."""
        sans = ["sans.example.com", "www.sans.example.com", "api.sans.example.com"]
        cert_info = self.vault.generate_certificate(
            common_name="sans.example.com",
            domains=sans,
        )

        self.assertEqual(len(cert_info["domains"]), 3)
        for domain in sans:
            with self.subTest(domain=domain):
                self.assertIn(domain, cert_info["domains"])

    def test_multiple_certificates(self):
        """Test managing multiple certificates."""
        names = ("cert1.example.com", "cert2.example.com", "cert3.example.com")
        for name in names:
            self.vault.generate_certificate(name)

        certificates = self.vault.list_certificates()
        self.assertEqual(len(certificates), 3)

        common_names = [cert["common_name"] for cert in certificates]
        for name in names:
            with self.subTest(name=name):
                self.assertIn(name, common_names)

    def test_trust_levels(self):
        """Test different trust levels."""
        self.vault.add_trusted_key("direct-peer", "age1direct", "direct")
        self.vault.add_trusted_key("indirect-peer", "age1indirect", "indirect")
        self.vault.add_trusted_key("verified-peer", "age1verified", "verified")

        trusted_keys = self.vault.get_trusted_keys()
        self.assertEqual(trusted_keys["direct-peer"]["trust_level"], "direct")
        self.assertEqual(trusted_keys["indirect-peer"]["trust_level"], "indirect")
        self.assertEqual(trusted_keys["verified-peer"]["trust_level"], "verified")

    def test_trust_chains_persistence(self):
        """Test that trust chains are persisted."""
        self.vault.add_trusted_key("test-peer", "age1testkey")
        self.vault.verify_trust_chain("test-peer", "sig1", "data1")
        self.vault.verify_trust_chain("test-peer", "sig2", "data2")

        # A fresh instance over the same directories must reload from disk,
        # and only the most recent verification is expected to be kept.
        new_vault = VaultManager(self.config_dir, self.data_dir)

        trust_chains = new_vault.get_trust_chains()
        self.assertIn("test-peer", trust_chains)
        self.assertEqual(trust_chains["test-peer"]["signature"], "sig2")
        self.assertEqual(trust_chains["test-peer"]["data"], "data2")

    def test_secrets_management(self):
        """Test storing, reading, listing and deleting a secret."""
        self.assertTrue(self.vault.store_secret('API_KEY', 'supersecret'))
        self.assertEqual(self.vault.get_secret('API_KEY'), 'supersecret')
        self.assertIn('API_KEY', self.vault.list_secrets())

        self.assertTrue(self.vault.delete_secret('API_KEY'))

        # After deletion the secret is neither readable nor listed.
        self.assertIsNone(self.vault.get_secret('API_KEY'))
        self.assertNotIn('API_KEY', self.vault.list_secrets())


class TestVaultManagerIntegration(_VaultTestCase):
    """Integration tests for VaultManager.

    Each test constructs its own VaultManager over the fixture directories.
    """

    def test_full_certificate_lifecycle(self):
        """Test complete certificate lifecycle."""
        vault = VaultManager(self.config_dir, self.data_dir)

        cert_info = vault.generate_certificate("lifecycle.example.com")
        self.assertTrue(cert_info["cert_file"])
        self.assertTrue(cert_info["key_file"])

        certificates = vault.list_certificates()
        self.assertEqual(len(certificates), 1)
        self.assertEqual(certificates[0]["common_name"], "lifecycle.example.com")

        result = vault.revoke_certificate("lifecycle.example.com")
        self.assertTrue(result)

        self.assertEqual(len(vault.list_certificates()), 0)

    def test_full_trust_lifecycle(self):
        """Test complete trust lifecycle."""
        vault = VaultManager(self.config_dir, self.data_dir)

        result = vault.add_trusted_key("trust-peer", "age1trustkey")
        self.assertTrue(result)

        result = vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data")
        self.assertTrue(result)

        result = vault.remove_trusted_key("trust-peer")
        self.assertTrue(result)

        # Once the key is removed the peer is no longer trusted.
        result = vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data")
        self.assertFalse(result)


if __name__ == '__main__':
    unittest.main()