#!/usr/bin/env python3
"""
Unit tests for VaultManager.

Tests certificate management, trust systems, and secret storage.
External commands (``subprocess.run``) are mocked in every test case.
"""

import json
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch

# Make the sibling ``api`` directory importable so ``vault_manager`` resolves.
# resolve() turns a relative __file__ into an absolute path, so the entry stays
# valid even if the working directory changes after this module is imported.
api_dir = Path(__file__).resolve().parent.parent / 'api'
sys.path.insert(0, str(api_dir))

from vault_manager import VaultManager  # noqa: E402  (requires the sys.path tweak above)
class TestVaultManager(unittest.TestCase):
    """Unit tests for VaultManager.

    Every test gets a fresh temporary config/data directory and a
    VaultManager built while ``subprocess.run`` is patched, so no external
    command is actually executed.
    """

    def setUp(self):
        """Create temp dirs, patch subprocess.run, and build a VaultManager."""
        self.test_dir = tempfile.mkdtemp()
        # Cleanups run even when a later setUp step raises (tearDown would
        # not), and in LIFO order: the patch is stopped before the dir is removed.
        self.addCleanup(shutil.rmtree, self.test_dir)

        self.config_dir = os.path.join(self.test_dir, "config")
        self.data_dir = os.path.join(self.test_dir, "data")
        os.makedirs(self.config_dir, exist_ok=True)
        os.makedirs(self.data_dir, exist_ok=True)

        # Patch subprocess.run globally in case VaultManager shells out
        # (e.g. to the Age CLI); every call returns a canned successful result.
        self.age_patcher = patch('subprocess.run')
        self.mock_age = self.age_patcher.start()
        self.addCleanup(self.age_patcher.stop)

        mock_result = MagicMock()
        mock_result.stdout = "age1testkey123456789\n"
        mock_result.returncode = 0
        self.mock_age.return_value = mock_result

        self.vault = VaultManager(self.config_dir, self.data_dir)

    def test_init_creates_directories(self):
        """Initialization creates the vault directory tree under data_dir."""
        vault_dir = Path(self.data_dir) / "vault"
        self.assertTrue(vault_dir.is_dir())
        for subdir in ("ca", "certs", "keys", "trust"):
            with self.subTest(subdir=subdir):
                self.assertTrue((vault_dir / subdir).is_dir())

    def test_ca_creation(self):
        """The CA key/cert files exist and the cert has basicConstraints CA=True."""
        from cryptography import x509

        self.assertTrue(self.vault.ca_key_file.exists())
        self.assertTrue(self.vault.ca_cert_file.exists())

        cert = x509.load_pem_x509_certificate(self.vault.ca_cert_file.read_bytes())
        basic_constraints = cert.extensions.get_extension_for_oid(
            x509.oid.ExtensionOID.BASIC_CONSTRAINTS
        )
        self.assertTrue(basic_constraints.value.ca)

    def test_generate_certificate(self):
        """generate_certificate returns metadata and writes cert + key files."""
        cert_info = self.vault.generate_certificate(
            common_name="test.example.com",
            domains=["test.example.com", "www.test.example.com"],
            key_size=2048,
            days=365,
        )

        self.assertEqual(cert_info["common_name"], "test.example.com")
        self.assertEqual(cert_info["domains"], ["test.example.com", "www.test.example.com"])
        self.assertTrue(cert_info["cert_file"])
        self.assertTrue(cert_info["key_file"])
        self.assertTrue(cert_info["encrypted"])

        self.assertTrue(Path(cert_info["cert_file"]).exists())
        self.assertTrue(Path(cert_info["key_file"]).exists())

    def test_generate_certificate_without_domains(self):
        """Omitting domains yields an empty domain list."""
        cert_info = self.vault.generate_certificate(common_name="simple.example.com")

        self.assertEqual(cert_info["common_name"], "simple.example.com")
        self.assertEqual(cert_info["domains"], [])

    def test_list_certificates(self):
        """A freshly generated certificate is listed and not expired."""
        self.vault.generate_certificate("test.example.com")

        certificates = self.vault.list_certificates()

        self.assertEqual(len(certificates), 1)
        self.assertEqual(certificates[0]["common_name"], "test.example.com")
        self.assertFalse(certificates[0]["expired"])

    def test_revoke_certificate(self):
        """Revoking a certificate deletes its .crt and .key files."""
        self.vault.generate_certificate("test.example.com")

        cert_file = self.vault.certs_dir / "test.example.com.crt"
        key_file = self.vault.certs_dir / "test.example.com.key"
        self.assertTrue(cert_file.exists())
        self.assertTrue(key_file.exists())

        self.assertTrue(self.vault.revoke_certificate("test.example.com"))

        self.assertFalse(cert_file.exists())
        self.assertFalse(key_file.exists())

    def test_revoke_nonexistent_certificate(self):
        """Revoking an unknown name reports success instead of raising."""
        result = self.vault.revoke_certificate("nonexistent.example.com")
        self.assertTrue(result)

    def test_add_trusted_key(self):
        """add_trusted_key stores the public key and trust level."""
        result = self.vault.add_trusted_key(
            name="test-peer",
            public_key="age1testkey123456789",
            trust_level="direct",
        )
        self.assertTrue(result)

        trusted_keys = self.vault.get_trusted_keys()
        self.assertIn("test-peer", trusted_keys)
        self.assertEqual(trusted_keys["test-peer"]["public_key"], "age1testkey123456789")
        self.assertEqual(trusted_keys["test-peer"]["trust_level"], "direct")

    def test_remove_trusted_key(self):
        """remove_trusted_key drops a previously added key."""
        self.vault.add_trusted_key("test-peer", "age1testkey123456789")

        self.assertTrue(self.vault.remove_trusted_key("test-peer"))
        self.assertNotIn("test-peer", self.vault.get_trusted_keys())

    def test_remove_nonexistent_trusted_key(self):
        """Removing an unknown key returns False."""
        self.assertFalse(self.vault.remove_trusted_key("nonexistent-peer"))

    def test_verify_trust_chain(self):
        """Verifying a known peer succeeds and records the signature/data."""
        self.vault.add_trusted_key("test-peer", "age1testkey123456789")

        result = self.vault.verify_trust_chain(
            peer_name="test-peer",
            signature="test-signature",
            data="test-data",
        )
        self.assertTrue(result)

        trust_chains = self.vault.get_trust_chains()
        self.assertIn("test-peer", trust_chains)
        self.assertEqual(trust_chains["test-peer"]["signature"], "test-signature")
        self.assertEqual(trust_chains["test-peer"]["data"], "test-data")

    def test_verify_trust_chain_unknown_peer(self):
        """Verifying a peer with no trusted key fails."""
        result = self.vault.verify_trust_chain(
            peer_name="unknown-peer",
            signature="test-signature",
            data="test-data",
        )
        self.assertFalse(result)

    def test_get_ca_certificate(self):
        """get_ca_certificate returns the CA certificate as a PEM string."""
        cert = self.vault.get_ca_certificate()

        self.assertIsInstance(cert, str)
        self.assertTrue(cert.startswith("-----BEGIN CERTIFICATE-----"))
        self.assertTrue(cert.endswith("-----END CERTIFICATE-----\n"))

    def test_get_status(self):
        """get_status returns a dict with every expected status field."""
        status = self.vault.get_status()

        self.assertIsInstance(status, dict)
        expected_keys = (
            "ca_configured",
            "age_configured",
            "certificates_count",
            "trusted_keys_count",
            "trust_chains_count",
            "certificates",
            "trusted_keys",
            "ca_certificate",
            "age_public_key",
        )
        for key in expected_keys:
            with self.subTest(key=key):
                self.assertIn(key, status)

        self.assertTrue(status["ca_configured"])
        self.assertIsInstance(status["certificates"], list)
        self.assertIsInstance(status["trusted_keys"], list)

    def test_certificate_with_sans(self):
        """All requested Subject Alternative Names are returned."""
        sans = ["sans.example.com", "www.sans.example.com", "api.sans.example.com"]
        cert_info = self.vault.generate_certificate(
            common_name="sans.example.com",
            domains=sans,
        )

        self.assertEqual(len(cert_info["domains"]), 3)
        for domain in sans:
            with self.subTest(domain=domain):
                self.assertIn(domain, cert_info["domains"])

    def test_multiple_certificates(self):
        """Several generated certificates are all listed."""
        names = ["cert1.example.com", "cert2.example.com", "cert3.example.com"]
        for name in names:
            self.vault.generate_certificate(name)

        certificates = self.vault.list_certificates()

        self.assertEqual(len(certificates), 3)
        self.assertCountEqual([cert["common_name"] for cert in certificates], names)

    def test_trust_levels(self):
        """Each trusted key keeps the trust level it was added with."""
        levels = {
            "direct-peer": ("age1direct", "direct"),
            "indirect-peer": ("age1indirect", "indirect"),
            "verified-peer": ("age1verified", "verified"),
        }
        for peer, (public_key, level) in levels.items():
            self.vault.add_trusted_key(peer, public_key, level)

        trusted_keys = self.vault.get_trusted_keys()

        for peer, (_, level) in levels.items():
            with self.subTest(peer=peer):
                self.assertEqual(trusted_keys[peer]["trust_level"], level)

    def test_trust_chains_persistence(self):
        """Trust chains survive a reload; the latest verification per peer wins."""
        self.vault.add_trusted_key("test-peer", "age1testkey")
        self.vault.verify_trust_chain("test-peer", "sig1", "data1")
        self.vault.verify_trust_chain("test-peer", "sig2", "data2")

        # A new instance over the same directories must load state from disk.
        new_vault = VaultManager(self.config_dir, self.data_dir)

        trust_chains = new_vault.get_trust_chains()
        self.assertIn("test-peer", trust_chains)
        self.assertEqual(trust_chains["test-peer"]["signature"], "sig2")
        self.assertEqual(trust_chains["test-peer"]["data"], "data2")

    def test_secrets_management(self):
        """Secrets can be stored, read, listed, and deleted."""
        self.assertTrue(self.vault.store_secret('API_KEY', 'supersecret'))
        self.assertEqual(self.vault.get_secret('API_KEY'), 'supersecret')
        self.assertIn('API_KEY', self.vault.list_secrets())

        self.assertTrue(self.vault.delete_secret('API_KEY'))

        # After deletion the secret is neither retrievable nor listed.
        self.assertIsNone(self.vault.get_secret('API_KEY'))
        self.assertNotIn('API_KEY', self.vault.list_secrets())
class TestVaultManagerIntegration(unittest.TestCase):
    """Integration tests for VaultManager covering full lifecycles.

    Each test constructs its own VaultManager over fresh temporary
    directories while ``subprocess.run`` is patched.
    """

    def setUp(self):
        """Create temp dirs and patch subprocess.run with a successful result."""
        self.test_dir = tempfile.mkdtemp()
        # Cleanups run even if setUp fails partway (tearDown would not), LIFO.
        self.addCleanup(shutil.rmtree, self.test_dir)

        self.config_dir = os.path.join(self.test_dir, "config")
        self.data_dir = os.path.join(self.test_dir, "data")
        os.makedirs(self.config_dir, exist_ok=True)
        os.makedirs(self.data_dir, exist_ok=True)

        # Patch subprocess.run globally in case VaultManager shells out
        # (e.g. to the Age CLI).
        self.age_patcher = patch('subprocess.run')
        self.mock_age = self.age_patcher.start()
        self.addCleanup(self.age_patcher.stop)

        mock_result = MagicMock()
        mock_result.stdout = "age1testkey123456789\n"
        # Without this, returncode would be an auto-created MagicMock (!= 0),
        # so any success check on the subprocess result would see a failure.
        mock_result.returncode = 0
        self.mock_age.return_value = mock_result

    def test_full_certificate_lifecycle(self):
        """Generate -> list -> revoke -> list leaves no certificates behind."""
        vault = VaultManager(self.config_dir, self.data_dir)

        cert_info = vault.generate_certificate("lifecycle.example.com")
        self.assertTrue(cert_info["cert_file"])
        self.assertTrue(cert_info["key_file"])

        certificates = vault.list_certificates()
        self.assertEqual(len(certificates), 1)
        self.assertEqual(certificates[0]["common_name"], "lifecycle.example.com")

        self.assertTrue(vault.revoke_certificate("lifecycle.example.com"))

        self.assertEqual(len(vault.list_certificates()), 0)

    def test_full_trust_lifecycle(self):
        """Add key -> verify succeeds -> remove key -> verify fails."""
        vault = VaultManager(self.config_dir, self.data_dir)

        self.assertTrue(vault.add_trusted_key("trust-peer", "age1trustkey"))
        self.assertTrue(vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data"))

        self.assertTrue(vault.remove_trusted_key("trust-peer"))

        # Once the key is gone the same peer can no longer be verified.
        self.assertFalse(vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data"))
if __name__ == '__main__':
    # Executed only when the file is run directly: load and run the test cases
    # found in this module (module='__main__' is unittest.main's default).
    unittest.main(module='__main__')