This commit is contained in:
Constantin
2025-09-12 23:04:52 +03:00
commit 2277b11563
127 changed files with 23640 additions and 0 deletions
+395
View File
@@ -0,0 +1,395 @@
#!/usr/bin/env python3
"""
Unit tests for VaultManager
Tests secure certificate management, trust systems, and Age encryption.
"""
import sys
from pathlib import Path
# Add api directory to path
api_dir = Path(__file__).parent.parent / 'api'
sys.path.insert(0, str(api_dir))
import unittest
import tempfile
import shutil
import os
import sys
from pathlib import Path
from unittest.mock import patch, MagicMock
import json
from datetime import datetime, timedelta
import subprocess
from vault_manager import VaultManager
class TestVaultManager(unittest.TestCase):
"""Test cases for VaultManager."""
def setUp(self):
"""Set up test environment."""
self.test_dir = tempfile.mkdtemp()
self.config_dir = os.path.join(self.test_dir, "config")
self.data_dir = os.path.join(self.test_dir, "data")
os.makedirs(self.config_dir, exist_ok=True)
os.makedirs(self.data_dir, exist_ok=True)
# Mock Age subprocess calls
self.age_patcher = patch('subprocess.run')
self.mock_age = self.age_patcher.start()
# Mock Age key generation output
mock_result = MagicMock()
mock_result.stdout = "age1testkey123456789\n"
mock_result.returncode = 0
self.mock_age.return_value = mock_result
self.vault = VaultManager(self.config_dir, self.data_dir)
# If Age keys were created (mocked), ensure the files exist
# (Removed Age key checks; Fernet is now used)
def tearDown(self):
"""Clean up test environment."""
self.age_patcher.stop()
shutil.rmtree(self.test_dir)
def test_init_creates_directories(self):
"""Test that initialization creates required directories."""
vault_dir = Path(self.data_dir) / "vault"
ca_dir = vault_dir / "ca"
certs_dir = vault_dir / "certs"
keys_dir = vault_dir / "keys"
trust_dir = vault_dir / "trust"
self.assertTrue(vault_dir.exists())
self.assertTrue(ca_dir.exists())
self.assertTrue(certs_dir.exists())
self.assertTrue(keys_dir.exists())
self.assertTrue(trust_dir.exists())
def test_ca_creation(self):
"""Test CA creation."""
self.assertTrue(self.vault.ca_key_file.exists())
self.assertTrue(self.vault.ca_cert_file.exists())
# Verify CA certificate properties
with open(self.vault.ca_cert_file, "rb") as f:
from cryptography import x509
cert = x509.load_pem_x509_certificate(f.read())
# Check basic constraints
basic_constraints = cert.extensions.get_extension_for_oid(
x509.oid.ExtensionOID.BASIC_CONSTRAINTS
)
self.assertTrue(basic_constraints.value.ca)
def test_generate_certificate(self):
"""Test certificate generation."""
cert_info = self.vault.generate_certificate(
common_name="test.example.com",
domains=["test.example.com", "www.test.example.com"],
key_size=2048,
days=365
)
self.assertEqual(cert_info["common_name"], "test.example.com")
self.assertEqual(cert_info["domains"], ["test.example.com", "www.test.example.com"])
self.assertTrue(cert_info["cert_file"])
self.assertTrue(cert_info["key_file"])
self.assertTrue(cert_info["encrypted"])
# Verify certificate file exists
cert_file = Path(cert_info["cert_file"])
key_file = Path(cert_info["key_file"])
self.assertTrue(cert_file.exists())
self.assertTrue(key_file.exists())
def test_generate_certificate_without_domains(self):
"""Test certificate generation without domains."""
cert_info = self.vault.generate_certificate(
common_name="simple.example.com"
)
self.assertEqual(cert_info["common_name"], "simple.example.com")
self.assertEqual(cert_info["domains"], [])
def test_list_certificates(self):
"""Test listing certificates."""
# Generate a test certificate
self.vault.generate_certificate("test.example.com")
certificates = self.vault.list_certificates()
self.assertEqual(len(certificates), 1)
self.assertEqual(certificates[0]["common_name"], "test.example.com")
self.assertFalse(certificates[0]["expired"])
def test_revoke_certificate(self):
"""Test certificate revocation."""
# Generate a test certificate
self.vault.generate_certificate("test.example.com")
# Verify certificate exists
cert_file = self.vault.certs_dir / "test.example.com.crt"
key_file = self.vault.certs_dir / "test.example.com.key"
self.assertTrue(cert_file.exists())
self.assertTrue(key_file.exists())
# Revoke certificate
result = self.vault.revoke_certificate("test.example.com")
self.assertTrue(result)
# Verify files are removed
self.assertFalse(cert_file.exists())
self.assertFalse(key_file.exists())
def test_revoke_nonexistent_certificate(self):
"""Test revoking non-existent certificate."""
result = self.vault.revoke_certificate("nonexistent.example.com")
self.assertTrue(result) # Should not raise exception
def test_add_trusted_key(self):
"""Test adding trusted key."""
result = self.vault.add_trusted_key(
name="test-peer",
public_key="age1testkey123456789",
trust_level="direct"
)
self.assertTrue(result)
# Verify key is added
trusted_keys = self.vault.get_trusted_keys()
self.assertIn("test-peer", trusted_keys)
self.assertEqual(trusted_keys["test-peer"]["public_key"], "age1testkey123456789")
self.assertEqual(trusted_keys["test-peer"]["trust_level"], "direct")
def test_remove_trusted_key(self):
"""Test removing trusted key."""
# Add a trusted key first
self.vault.add_trusted_key("test-peer", "age1testkey123456789")
# Remove the key
result = self.vault.remove_trusted_key("test-peer")
self.assertTrue(result)
# Verify key is removed
trusted_keys = self.vault.get_trusted_keys()
self.assertNotIn("test-peer", trusted_keys)
def test_remove_nonexistent_trusted_key(self):
"""Test removing non-existent trusted key."""
result = self.vault.remove_trusted_key("nonexistent-peer")
self.assertFalse(result)
def test_verify_trust_chain(self):
"""Test trust chain verification."""
# Add a trusted key first
self.vault.add_trusted_key("test-peer", "age1testkey123456789")
# Verify trust chain
result = self.vault.verify_trust_chain(
peer_name="test-peer",
signature="test-signature",
data="test-data"
)
self.assertTrue(result)
# Verify trust chain is recorded
trust_chains = self.vault.get_trust_chains()
self.assertIn("test-peer", trust_chains)
self.assertEqual(trust_chains["test-peer"]["signature"], "test-signature")
self.assertEqual(trust_chains["test-peer"]["data"], "test-data")
def test_verify_trust_chain_unknown_peer(self):
"""Test trust chain verification with unknown peer."""
result = self.vault.verify_trust_chain(
peer_name="unknown-peer",
signature="test-signature",
data="test-data"
)
self.assertFalse(result)
def test_get_ca_certificate(self):
"""Test getting CA certificate."""
cert = self.vault.get_ca_certificate()
self.assertIsInstance(cert, str)
self.assertTrue(cert.startswith("-----BEGIN CERTIFICATE-----"))
self.assertTrue(cert.endswith("-----END CERTIFICATE-----\n"))
def test_get_status(self):
"""Test getting vault status."""
status = self.vault.get_status()
self.assertIsInstance(status, dict)
self.assertIn("ca_configured", status)
self.assertIn("age_configured", status)
self.assertIn("certificates_count", status)
self.assertIn("trusted_keys_count", status)
self.assertIn("trust_chains_count", status)
self.assertIn("certificates", status)
self.assertIn("trusted_keys", status)
self.assertIn("ca_certificate", status)
self.assertIn("age_public_key", status)
self.assertTrue(status["ca_configured"])
self.assertIsInstance(status["certificates"], list)
self.assertIsInstance(status["trusted_keys"], list)
# Remove test_encrypt_file_with_age, test_decrypt_file_with_age, and any other Age-related tests
def test_certificate_with_sans(self):
"""Test certificate generation with Subject Alternative Names."""
cert_info = self.vault.generate_certificate(
common_name="sans.example.com",
domains=["sans.example.com", "www.sans.example.com", "api.sans.example.com"]
)
self.assertEqual(len(cert_info["domains"]), 3)
self.assertIn("sans.example.com", cert_info["domains"])
self.assertIn("www.sans.example.com", cert_info["domains"])
self.assertIn("api.sans.example.com", cert_info["domains"])
def test_multiple_certificates(self):
"""Test managing multiple certificates."""
# Generate multiple certificates
cert1 = self.vault.generate_certificate("cert1.example.com")
cert2 = self.vault.generate_certificate("cert2.example.com")
cert3 = self.vault.generate_certificate("cert3.example.com")
# List all certificates
certificates = self.vault.list_certificates()
self.assertEqual(len(certificates), 3)
# Verify all certificates are listed
common_names = [cert["common_name"] for cert in certificates]
self.assertIn("cert1.example.com", common_names)
self.assertIn("cert2.example.com", common_names)
self.assertIn("cert3.example.com", common_names)
def test_trust_levels(self):
"""Test different trust levels."""
# Add keys with different trust levels
self.vault.add_trusted_key("direct-peer", "age1direct", "direct")
self.vault.add_trusted_key("indirect-peer", "age1indirect", "indirect")
self.vault.add_trusted_key("verified-peer", "age1verified", "verified")
trusted_keys = self.vault.get_trusted_keys()
self.assertEqual(trusted_keys["direct-peer"]["trust_level"], "direct")
self.assertEqual(trusted_keys["indirect-peer"]["trust_level"], "indirect")
self.assertEqual(trusted_keys["verified-peer"]["trust_level"], "verified")
def test_trust_chains_persistence(self):
"""Test that trust chains are persisted."""
# Add a trusted key
self.vault.add_trusted_key("test-peer", "age1testkey")
# Verify trust chain
self.vault.verify_trust_chain("test-peer", "sig1", "data1")
self.vault.verify_trust_chain("test-peer", "sig2", "data2")
# Create new vault instance (should load from disk)
new_vault = VaultManager(self.config_dir, self.data_dir)
# Verify trust chains are loaded
trust_chains = new_vault.get_trust_chains()
self.assertIn("test-peer", trust_chains)
self.assertEqual(trust_chains["test-peer"]["signature"], "sig2")
self.assertEqual(trust_chains["test-peer"]["data"], "data2")
def test_secrets_management(self):
# Store secret
self.assertTrue(self.vault.store_secret('API_KEY', 'supersecret'))
# Get secret
self.assertEqual(self.vault.get_secret('API_KEY'), 'supersecret')
# List secrets
self.assertIn('API_KEY', self.vault.list_secrets())
# Delete secret
self.assertTrue(self.vault.delete_secret('API_KEY'))
# Secret should be gone
self.assertIsNone(self.vault.get_secret('API_KEY'))
self.assertNotIn('API_KEY', self.vault.list_secrets())
class TestVaultManagerIntegration(unittest.TestCase):
"""Integration tests for VaultManager."""
def setUp(self):
"""Set up test environment."""
self.test_dir = tempfile.mkdtemp()
self.config_dir = os.path.join(self.test_dir, "config")
self.data_dir = os.path.join(self.test_dir, "data")
os.makedirs(self.config_dir, exist_ok=True)
os.makedirs(self.data_dir, exist_ok=True)
# Mock Age subprocess calls
self.age_patcher = patch('subprocess.run')
self.mock_age = self.age_patcher.start()
# Mock Age key generation output
mock_result = MagicMock()
mock_result.stdout = "age1testkey123456789\n"
self.mock_age.return_value = mock_result
def tearDown(self):
"""Clean up test environment."""
self.age_patcher.stop()
shutil.rmtree(self.test_dir)
def test_full_certificate_lifecycle(self):
"""Test complete certificate lifecycle."""
vault = VaultManager(self.config_dir, self.data_dir)
# Generate certificate
cert_info = vault.generate_certificate("lifecycle.example.com")
self.assertTrue(cert_info["cert_file"])
self.assertTrue(cert_info["key_file"])
# List certificates
certificates = vault.list_certificates()
self.assertEqual(len(certificates), 1)
self.assertEqual(certificates[0]["common_name"], "lifecycle.example.com")
# Revoke certificate
result = vault.revoke_certificate("lifecycle.example.com")
self.assertTrue(result)
# Verify certificate is removed
certificates = vault.list_certificates()
self.assertEqual(len(certificates), 0)
def test_full_trust_lifecycle(self):
"""Test complete trust lifecycle."""
vault = VaultManager(self.config_dir, self.data_dir)
# Add trusted key
result = vault.add_trusted_key("trust-peer", "age1trustkey")
self.assertTrue(result)
# Verify trust chain
result = vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data")
self.assertTrue(result)
# Remove trusted key
result = vault.remove_trusted_key("trust-peer")
self.assertTrue(result)
# Verify trust chain verification fails
result = vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data")
self.assertFalse(result)
if __name__ == '__main__':
unittest.main()