init
This commit is contained in:
@@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unit tests for VaultManager
|
||||
|
||||
Tests secure certificate management, trust systems, and Age encryption.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add api directory to path
|
||||
api_dir = Path(__file__).parent.parent / 'api'
|
||||
sys.path.insert(0, str(api_dir))
|
||||
import unittest
|
||||
import tempfile
|
||||
import shutil
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
import subprocess
|
||||
|
||||
from vault_manager import VaultManager
|
||||
|
||||
|
||||
class TestVaultManager(unittest.TestCase):
|
||||
"""Test cases for VaultManager."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.config_dir = os.path.join(self.test_dir, "config")
|
||||
self.data_dir = os.path.join(self.test_dir, "data")
|
||||
|
||||
os.makedirs(self.config_dir, exist_ok=True)
|
||||
os.makedirs(self.data_dir, exist_ok=True)
|
||||
|
||||
# Mock Age subprocess calls
|
||||
self.age_patcher = patch('subprocess.run')
|
||||
self.mock_age = self.age_patcher.start()
|
||||
|
||||
# Mock Age key generation output
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = "age1testkey123456789\n"
|
||||
mock_result.returncode = 0
|
||||
self.mock_age.return_value = mock_result
|
||||
|
||||
self.vault = VaultManager(self.config_dir, self.data_dir)
|
||||
|
||||
# If Age keys were created (mocked), ensure the files exist
|
||||
# (Removed Age key checks; Fernet is now used)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
self.age_patcher.stop()
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_init_creates_directories(self):
|
||||
"""Test that initialization creates required directories."""
|
||||
vault_dir = Path(self.data_dir) / "vault"
|
||||
ca_dir = vault_dir / "ca"
|
||||
certs_dir = vault_dir / "certs"
|
||||
keys_dir = vault_dir / "keys"
|
||||
trust_dir = vault_dir / "trust"
|
||||
|
||||
self.assertTrue(vault_dir.exists())
|
||||
self.assertTrue(ca_dir.exists())
|
||||
self.assertTrue(certs_dir.exists())
|
||||
self.assertTrue(keys_dir.exists())
|
||||
self.assertTrue(trust_dir.exists())
|
||||
|
||||
def test_ca_creation(self):
|
||||
"""Test CA creation."""
|
||||
self.assertTrue(self.vault.ca_key_file.exists())
|
||||
self.assertTrue(self.vault.ca_cert_file.exists())
|
||||
|
||||
# Verify CA certificate properties
|
||||
with open(self.vault.ca_cert_file, "rb") as f:
|
||||
from cryptography import x509
|
||||
cert = x509.load_pem_x509_certificate(f.read())
|
||||
|
||||
# Check basic constraints
|
||||
basic_constraints = cert.extensions.get_extension_for_oid(
|
||||
x509.oid.ExtensionOID.BASIC_CONSTRAINTS
|
||||
)
|
||||
self.assertTrue(basic_constraints.value.ca)
|
||||
|
||||
def test_generate_certificate(self):
|
||||
"""Test certificate generation."""
|
||||
cert_info = self.vault.generate_certificate(
|
||||
common_name="test.example.com",
|
||||
domains=["test.example.com", "www.test.example.com"],
|
||||
key_size=2048,
|
||||
days=365
|
||||
)
|
||||
|
||||
self.assertEqual(cert_info["common_name"], "test.example.com")
|
||||
self.assertEqual(cert_info["domains"], ["test.example.com", "www.test.example.com"])
|
||||
self.assertTrue(cert_info["cert_file"])
|
||||
self.assertTrue(cert_info["key_file"])
|
||||
self.assertTrue(cert_info["encrypted"])
|
||||
|
||||
# Verify certificate file exists
|
||||
cert_file = Path(cert_info["cert_file"])
|
||||
key_file = Path(cert_info["key_file"])
|
||||
|
||||
self.assertTrue(cert_file.exists())
|
||||
self.assertTrue(key_file.exists())
|
||||
|
||||
def test_generate_certificate_without_domains(self):
|
||||
"""Test certificate generation without domains."""
|
||||
cert_info = self.vault.generate_certificate(
|
||||
common_name="simple.example.com"
|
||||
)
|
||||
|
||||
self.assertEqual(cert_info["common_name"], "simple.example.com")
|
||||
self.assertEqual(cert_info["domains"], [])
|
||||
|
||||
def test_list_certificates(self):
|
||||
"""Test listing certificates."""
|
||||
# Generate a test certificate
|
||||
self.vault.generate_certificate("test.example.com")
|
||||
|
||||
certificates = self.vault.list_certificates()
|
||||
|
||||
self.assertEqual(len(certificates), 1)
|
||||
self.assertEqual(certificates[0]["common_name"], "test.example.com")
|
||||
self.assertFalse(certificates[0]["expired"])
|
||||
|
||||
def test_revoke_certificate(self):
|
||||
"""Test certificate revocation."""
|
||||
# Generate a test certificate
|
||||
self.vault.generate_certificate("test.example.com")
|
||||
|
||||
# Verify certificate exists
|
||||
cert_file = self.vault.certs_dir / "test.example.com.crt"
|
||||
key_file = self.vault.certs_dir / "test.example.com.key"
|
||||
|
||||
self.assertTrue(cert_file.exists())
|
||||
self.assertTrue(key_file.exists())
|
||||
|
||||
# Revoke certificate
|
||||
result = self.vault.revoke_certificate("test.example.com")
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify files are removed
|
||||
self.assertFalse(cert_file.exists())
|
||||
self.assertFalse(key_file.exists())
|
||||
|
||||
def test_revoke_nonexistent_certificate(self):
|
||||
"""Test revoking non-existent certificate."""
|
||||
result = self.vault.revoke_certificate("nonexistent.example.com")
|
||||
self.assertTrue(result) # Should not raise exception
|
||||
|
||||
def test_add_trusted_key(self):
|
||||
"""Test adding trusted key."""
|
||||
result = self.vault.add_trusted_key(
|
||||
name="test-peer",
|
||||
public_key="age1testkey123456789",
|
||||
trust_level="direct"
|
||||
)
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify key is added
|
||||
trusted_keys = self.vault.get_trusted_keys()
|
||||
self.assertIn("test-peer", trusted_keys)
|
||||
self.assertEqual(trusted_keys["test-peer"]["public_key"], "age1testkey123456789")
|
||||
self.assertEqual(trusted_keys["test-peer"]["trust_level"], "direct")
|
||||
|
||||
def test_remove_trusted_key(self):
|
||||
"""Test removing trusted key."""
|
||||
# Add a trusted key first
|
||||
self.vault.add_trusted_key("test-peer", "age1testkey123456789")
|
||||
|
||||
# Remove the key
|
||||
result = self.vault.remove_trusted_key("test-peer")
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify key is removed
|
||||
trusted_keys = self.vault.get_trusted_keys()
|
||||
self.assertNotIn("test-peer", trusted_keys)
|
||||
|
||||
def test_remove_nonexistent_trusted_key(self):
|
||||
"""Test removing non-existent trusted key."""
|
||||
result = self.vault.remove_trusted_key("nonexistent-peer")
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_verify_trust_chain(self):
|
||||
"""Test trust chain verification."""
|
||||
# Add a trusted key first
|
||||
self.vault.add_trusted_key("test-peer", "age1testkey123456789")
|
||||
|
||||
# Verify trust chain
|
||||
result = self.vault.verify_trust_chain(
|
||||
peer_name="test-peer",
|
||||
signature="test-signature",
|
||||
data="test-data"
|
||||
)
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify trust chain is recorded
|
||||
trust_chains = self.vault.get_trust_chains()
|
||||
self.assertIn("test-peer", trust_chains)
|
||||
self.assertEqual(trust_chains["test-peer"]["signature"], "test-signature")
|
||||
self.assertEqual(trust_chains["test-peer"]["data"], "test-data")
|
||||
|
||||
def test_verify_trust_chain_unknown_peer(self):
|
||||
"""Test trust chain verification with unknown peer."""
|
||||
result = self.vault.verify_trust_chain(
|
||||
peer_name="unknown-peer",
|
||||
signature="test-signature",
|
||||
data="test-data"
|
||||
)
|
||||
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_get_ca_certificate(self):
|
||||
"""Test getting CA certificate."""
|
||||
cert = self.vault.get_ca_certificate()
|
||||
|
||||
self.assertIsInstance(cert, str)
|
||||
self.assertTrue(cert.startswith("-----BEGIN CERTIFICATE-----"))
|
||||
self.assertTrue(cert.endswith("-----END CERTIFICATE-----\n"))
|
||||
|
||||
def test_get_status(self):
|
||||
"""Test getting vault status."""
|
||||
status = self.vault.get_status()
|
||||
|
||||
self.assertIsInstance(status, dict)
|
||||
self.assertIn("ca_configured", status)
|
||||
self.assertIn("age_configured", status)
|
||||
self.assertIn("certificates_count", status)
|
||||
self.assertIn("trusted_keys_count", status)
|
||||
self.assertIn("trust_chains_count", status)
|
||||
self.assertIn("certificates", status)
|
||||
self.assertIn("trusted_keys", status)
|
||||
self.assertIn("ca_certificate", status)
|
||||
self.assertIn("age_public_key", status)
|
||||
|
||||
self.assertTrue(status["ca_configured"])
|
||||
self.assertIsInstance(status["certificates"], list)
|
||||
self.assertIsInstance(status["trusted_keys"], list)
|
||||
|
||||
# Remove test_encrypt_file_with_age, test_decrypt_file_with_age, and any other Age-related tests
|
||||
|
||||
def test_certificate_with_sans(self):
|
||||
"""Test certificate generation with Subject Alternative Names."""
|
||||
cert_info = self.vault.generate_certificate(
|
||||
common_name="sans.example.com",
|
||||
domains=["sans.example.com", "www.sans.example.com", "api.sans.example.com"]
|
||||
)
|
||||
|
||||
self.assertEqual(len(cert_info["domains"]), 3)
|
||||
self.assertIn("sans.example.com", cert_info["domains"])
|
||||
self.assertIn("www.sans.example.com", cert_info["domains"])
|
||||
self.assertIn("api.sans.example.com", cert_info["domains"])
|
||||
|
||||
def test_multiple_certificates(self):
|
||||
"""Test managing multiple certificates."""
|
||||
# Generate multiple certificates
|
||||
cert1 = self.vault.generate_certificate("cert1.example.com")
|
||||
cert2 = self.vault.generate_certificate("cert2.example.com")
|
||||
cert3 = self.vault.generate_certificate("cert3.example.com")
|
||||
|
||||
# List all certificates
|
||||
certificates = self.vault.list_certificates()
|
||||
|
||||
self.assertEqual(len(certificates), 3)
|
||||
|
||||
# Verify all certificates are listed
|
||||
common_names = [cert["common_name"] for cert in certificates]
|
||||
self.assertIn("cert1.example.com", common_names)
|
||||
self.assertIn("cert2.example.com", common_names)
|
||||
self.assertIn("cert3.example.com", common_names)
|
||||
|
||||
def test_trust_levels(self):
|
||||
"""Test different trust levels."""
|
||||
# Add keys with different trust levels
|
||||
self.vault.add_trusted_key("direct-peer", "age1direct", "direct")
|
||||
self.vault.add_trusted_key("indirect-peer", "age1indirect", "indirect")
|
||||
self.vault.add_trusted_key("verified-peer", "age1verified", "verified")
|
||||
|
||||
trusted_keys = self.vault.get_trusted_keys()
|
||||
|
||||
self.assertEqual(trusted_keys["direct-peer"]["trust_level"], "direct")
|
||||
self.assertEqual(trusted_keys["indirect-peer"]["trust_level"], "indirect")
|
||||
self.assertEqual(trusted_keys["verified-peer"]["trust_level"], "verified")
|
||||
|
||||
def test_trust_chains_persistence(self):
|
||||
"""Test that trust chains are persisted."""
|
||||
# Add a trusted key
|
||||
self.vault.add_trusted_key("test-peer", "age1testkey")
|
||||
|
||||
# Verify trust chain
|
||||
self.vault.verify_trust_chain("test-peer", "sig1", "data1")
|
||||
self.vault.verify_trust_chain("test-peer", "sig2", "data2")
|
||||
|
||||
# Create new vault instance (should load from disk)
|
||||
new_vault = VaultManager(self.config_dir, self.data_dir)
|
||||
|
||||
# Verify trust chains are loaded
|
||||
trust_chains = new_vault.get_trust_chains()
|
||||
self.assertIn("test-peer", trust_chains)
|
||||
self.assertEqual(trust_chains["test-peer"]["signature"], "sig2")
|
||||
self.assertEqual(trust_chains["test-peer"]["data"], "data2")
|
||||
|
||||
def test_secrets_management(self):
|
||||
# Store secret
|
||||
self.assertTrue(self.vault.store_secret('API_KEY', 'supersecret'))
|
||||
# Get secret
|
||||
self.assertEqual(self.vault.get_secret('API_KEY'), 'supersecret')
|
||||
# List secrets
|
||||
self.assertIn('API_KEY', self.vault.list_secrets())
|
||||
# Delete secret
|
||||
self.assertTrue(self.vault.delete_secret('API_KEY'))
|
||||
# Secret should be gone
|
||||
self.assertIsNone(self.vault.get_secret('API_KEY'))
|
||||
self.assertNotIn('API_KEY', self.vault.list_secrets())
|
||||
|
||||
|
||||
class TestVaultManagerIntegration(unittest.TestCase):
|
||||
"""Integration tests for VaultManager."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.config_dir = os.path.join(self.test_dir, "config")
|
||||
self.data_dir = os.path.join(self.test_dir, "data")
|
||||
|
||||
os.makedirs(self.config_dir, exist_ok=True)
|
||||
os.makedirs(self.data_dir, exist_ok=True)
|
||||
|
||||
# Mock Age subprocess calls
|
||||
self.age_patcher = patch('subprocess.run')
|
||||
self.mock_age = self.age_patcher.start()
|
||||
|
||||
# Mock Age key generation output
|
||||
mock_result = MagicMock()
|
||||
mock_result.stdout = "age1testkey123456789\n"
|
||||
self.mock_age.return_value = mock_result
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
self.age_patcher.stop()
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_full_certificate_lifecycle(self):
|
||||
"""Test complete certificate lifecycle."""
|
||||
vault = VaultManager(self.config_dir, self.data_dir)
|
||||
|
||||
# Generate certificate
|
||||
cert_info = vault.generate_certificate("lifecycle.example.com")
|
||||
self.assertTrue(cert_info["cert_file"])
|
||||
self.assertTrue(cert_info["key_file"])
|
||||
|
||||
# List certificates
|
||||
certificates = vault.list_certificates()
|
||||
self.assertEqual(len(certificates), 1)
|
||||
self.assertEqual(certificates[0]["common_name"], "lifecycle.example.com")
|
||||
|
||||
# Revoke certificate
|
||||
result = vault.revoke_certificate("lifecycle.example.com")
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify certificate is removed
|
||||
certificates = vault.list_certificates()
|
||||
self.assertEqual(len(certificates), 0)
|
||||
|
||||
def test_full_trust_lifecycle(self):
|
||||
"""Test complete trust lifecycle."""
|
||||
vault = VaultManager(self.config_dir, self.data_dir)
|
||||
|
||||
# Add trusted key
|
||||
result = vault.add_trusted_key("trust-peer", "age1trustkey")
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify trust chain
|
||||
result = vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data")
|
||||
self.assertTrue(result)
|
||||
|
||||
# Remove trusted key
|
||||
result = vault.remove_trusted_key("trust-peer")
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify trust chain verification fails
|
||||
result = vault.verify_trust_chain("trust-peer", "trust-sig", "trust-data")
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user