From 5f6640d8f37b38785a1d1339d95bf96605fee50b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Benouarets?= Date: Tue, 30 Sep 2025 11:44:51 +0200 Subject: [PATCH] feat: implement certificate authority service - Add CertificateAuthorityService for CA management - Support creation of root and intermediate CAs - Implement CA certificate generation with proper X.509 attributes - Add CA validation and verification functionality - Include comprehensive test coverage for CA operations - Support multiple CA types and configurations - Add proper error handling and logging --- certificate/authority.go | 637 +++++++++++++++++++++++++++++++ certificate/authority_test.go | 690 ++++++++++++++++++++++++++++++++++ 2 files changed, 1327 insertions(+) create mode 100644 certificate/authority.go create mode 100644 certificate/authority_test.go diff --git a/certificate/authority.go b/certificate/authority.go new file mode 100644 index 0000000..3849a7d --- /dev/null +++ b/certificate/authority.go @@ -0,0 +1,637 @@ +package certificate + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" + "time" + + "git.secnex.io/secnex/certman/certificate/utils" + "git.secnex.io/secnex/certman/models" + "git.secnex.io/secnex/certman/repositories" + "github.com/google/uuid" + "gorm.io/gorm" +) + +// CertificateAuthorityService handles certificate authority operations +type CertificateAuthorityService struct { + caRepo *repositories.CertificateAuthorityRepository + orgRepo *repositories.OrganizationRepository + certDir string + privateDir string +} + +// NewCertificateAuthorityService creates a new certificate authority service +func NewCertificateAuthorityService( + db *gorm.DB, + certDir string, + privateDir string, +) *CertificateAuthorityService { + return &CertificateAuthorityService{ + caRepo: repositories.NewCertificateAuthorityRepository(db), + orgRepo: repositories.NewOrganizationRepository(db), + certDir: certDir, + privateDir: privateDir, + } +} + +// CreateRootCA creates a new root certificate authority +func (s *CertificateAuthorityService) CreateRootCA(req *CreateRootCARequest) (*models.CertificateAuthority, error) { + // Validate request + if err := s.validateRootCARequest(req); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + // Check if organization exists + org, err := s.orgRepo.GetByID(req.OrganizationID.String()) + if err != nil { + return nil, fmt.Errorf("organization not found: %w", err) + } + + // Check if root CA already exists for this organization + existingCAs, err := s.caRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to check existing CAs: %w", err) + } + + for _, ca := range existingCAs { + if ca.Root && ca.OrganizationID == req.OrganizationID { + return nil, fmt.Errorf("root CA already exists for organization %s", org.Name) + } + } + + // Create certificate configuration + var config *utils.CertificateConfig + if req.ValidityYears > 0 { + config = utils.CACertificateConfigWithValidity(true, req.ValidityYears) // true = isRoot + } else { + config = utils.CACertificateConfig(true) // true = isRoot + } + config.CommonName = req.CommonName + config.Organization = org.Name + config.OrganizationalUnit = req.OrganizationalUnit + config.Country = req.Country + config.State = req.State + config.Locality = req.Locality + config.Street = req.Street + config.PostalCode = req.PostalCode + config.Email = req.Email + + // Validate configuration + if err := utils.ValidateCertificateConfig(config); err != nil { + return nil, fmt.Errorf("invalid certificate configuration: %w", err) + } + + // Generate certificate and private key + generator := utils.NewCertificateGenerator(config) + cert, privateKey, err := generator.GenerateSelfSignedCertificate() + if err != nil { + return nil, fmt.Errorf("failed to generate root CA certificate: %w", err) + } + + // Save certificate and private key to files + certFileID, err := s.saveCertificate(cert, "root-ca") + if err != nil { + return nil, fmt.Errorf("failed to save certificate: %w", err) + } + + privateKeyFileID, err := s.savePrivateKey(privateKey, "root-ca") + if err != nil { + return nil, fmt.Errorf("failed to save private key: %w", err) + } + + // Create CA model + ca := &models.CertificateAuthority{ + Name: req.Name, + Description: req.Description, + SerialNumber: cert.SerialNumber.String(), + AttributeCommonName: cert.Subject.CommonName, + AttributeOrganization: cert.Subject.Organization[0], + AttributeOrganizationUnit: cert.Subject.OrganizationalUnit[0], + AttributeCountry: cert.Subject.Country[0], + AttributeState: cert.Subject.Province[0], + AttributeLocality: cert.Subject.Locality[0], + AttributeStreet: cert.Subject.StreetAddress[0], + AttributeEmail: req.Email, + AttributeAddress: req.Address, + AttributePostalCode: cert.Subject.PostalCode[0], + AttributeNotBefore: cert.NotBefore, + AttributeNotAfter: cert.NotAfter, + Root: true, + ParentID: nil, // Root CA has no parent + OrganizationID: req.OrganizationID, + FileID: certFileID, + PrivateKeyID: privateKeyFileID, + } + + // Save to database + createdCA, err := s.caRepo.Create(*ca) + if err != nil { + // Clean up files if database save fails + s.cleanupFiles(certFileID, privateKeyFileID) + return nil, fmt.Errorf("failed to save CA to database: %w", err) + } + + return &createdCA, nil +} + +// CreateIntermediateCA creates a new intermediate certificate authority +func (s *CertificateAuthorityService) CreateIntermediateCA(req *CreateIntermediateCARequest) (*models.CertificateAuthority, error) { + // Validate request + if err := s.validateIntermediateCARequest(req); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + // Check if organization exists + org, err := s.orgRepo.GetByID(req.OrganizationID.String()) + if err != nil { + return nil, fmt.Errorf("organization not found: %w", err) + } + + // Get parent CA + parentCA, err := s.caRepo.GetByID(req.ParentCAID.String()) + if err != nil { + return nil, fmt.Errorf("parent CA not found: %w", err) + } + + // Load parent CA certificate and private key + parentCert, parentPrivateKey, err := s.loadCACertificateAndKey(&parentCA) + if err != nil { + return nil, fmt.Errorf("failed to load parent CA certificate: %w", err) + } + + // Create certificate configuration + var config *utils.CertificateConfig + if req.ValidityYears > 0 { + config = utils.CACertificateConfigWithValidity(false, req.ValidityYears) // false = not root + } else { + config = utils.CACertificateConfig(false) // false = not root + } + config.CommonName = req.CommonName + config.Organization = org.Name + config.OrganizationalUnit = req.OrganizationalUnit + config.Country = req.Country + config.State = req.State + config.Locality = req.Locality + config.Street = req.Street + config.PostalCode = req.PostalCode + config.Email = req.Email + + // Validate configuration + if err := utils.ValidateCertificateConfig(config); err != nil { + return nil, fmt.Errorf("invalid certificate configuration: %w", err) + } + + // Generate certificate and private key + generator := utils.NewCertificateGenerator(config) + cert, privateKey, err := generator.GenerateCertificate(parentCert, parentPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to generate intermediate CA certificate: %w", err) + } + + // Save certificate and private key to files + certFileID, err := s.saveCertificate(cert, "intermediate-ca") + if err != nil { + return nil, fmt.Errorf("failed to save certificate: %w", err) + } + + privateKeyFileID, err := s.savePrivateKey(privateKey, "intermediate-ca") + if err != nil { + return nil, fmt.Errorf("failed to save private key: %w", err) + } + + // Create CA model + ca := &models.CertificateAuthority{ + Name: req.Name, + Description: req.Description, + SerialNumber: cert.SerialNumber.String(), + AttributeCommonName: cert.Subject.CommonName, + AttributeOrganization: cert.Subject.Organization[0], + AttributeOrganizationUnit: cert.Subject.OrganizationalUnit[0], + AttributeCountry: cert.Subject.Country[0], + AttributeState: cert.Subject.Province[0], + AttributeLocality: cert.Subject.Locality[0], + AttributeStreet: cert.Subject.StreetAddress[0], + AttributeEmail: req.Email, + AttributeAddress: req.Address, + AttributePostalCode: cert.Subject.PostalCode[0], + AttributeNotBefore: cert.NotBefore, + AttributeNotAfter: cert.NotAfter, + Root: false, + ParentID: &req.ParentCAID, + OrganizationID: req.OrganizationID, + FileID: certFileID, + PrivateKeyID: privateKeyFileID, + } + + // Save to database + createdCA, err := s.caRepo.Create(*ca) + if err != nil { + // Clean up files if database save fails + s.cleanupFiles(certFileID, privateKeyFileID) + return nil, fmt.Errorf("failed to save CA to database: %w", err) + } + + return &createdCA, nil +} + +// GetCA retrieves a certificate authority by ID +func (s *CertificateAuthorityService) GetCA(id string) (*models.CertificateAuthority, error) { + ca, err := s.caRepo.GetByID(id) + if err != nil { + return nil, fmt.Errorf("CA not found: %w", err) + } + return &ca, nil +} + +// GetAllCAs retrieves all certificate authorities +func (s *CertificateAuthorityService) GetAllCAs() ([]models.CertificateAuthority, error) { + cas, err := s.caRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve CAs: %w", err) + } + return cas, nil +} + +// GetRootCAs retrieves all root certificate authorities +func (s *CertificateAuthorityService) GetRootCAs() ([]models.CertificateAuthority, error) { + allCAs, err := s.caRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve CAs: %w", err) + } + + var rootCAs []models.CertificateAuthority + for _, ca := range allCAs { + if ca.Root { + rootCAs = append(rootCAs, ca) + } + } + + return rootCAs, nil +} + +// GetIntermediateCAs retrieves all intermediate certificate authorities +func (s *CertificateAuthorityService) GetIntermediateCAs() ([]models.CertificateAuthority, error) { + allCAs, err := s.caRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve CAs: %w", err) + } + + var intermediateCAs []models.CertificateAuthority + for _, ca := range allCAs { + if !ca.Root { + intermediateCAs = append(intermediateCAs, ca) + } + } + + return intermediateCAs, nil +} + +// GetCAByParent retrieves all certificate authorities under a specific parent +func (s *CertificateAuthorityService) GetCAByParent(parentID string) ([]models.CertificateAuthority, error) { + allCAs, err := s.caRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve CAs: %w", err) + } + + var childCAs []models.CertificateAuthority + for _, ca := range allCAs { + if ca.ParentID != nil && ca.ParentID.String() == parentID { + childCAs = append(childCAs, ca) + } + } + + return childCAs, nil +} + +// GetCACertificate retrieves the certificate for a CA +func (s *CertificateAuthorityService) GetCACertificate(caID string) (*x509.Certificate, error) { + ca, err := s.caRepo.GetByID(caID) + if err != nil { + return nil, fmt.Errorf("CA not found: %w", err) + } + + cert, _, err := s.loadCACertificateAndKey(&ca) + if err != nil { + return nil, fmt.Errorf("failed to load CA certificate: %w", err) + } + + return cert, nil +} + +// GetCAPrivateKey retrieves the private key for a CA +func (s *CertificateAuthorityService) GetCAPrivateKey(caID string) (interface{}, error) { + ca, err := s.caRepo.GetByID(caID) + if err != nil { + return nil, fmt.Errorf("CA not found: %w", err) + } + + _, privateKey, err := s.loadCACertificateAndKey(&ca) + if err != nil { + return nil, fmt.Errorf("failed to load CA private key: %w", err) + } + + return privateKey, nil +} + +// UpdateCA updates a certificate authority +func (s *CertificateAuthorityService) UpdateCA(ca *models.CertificateAuthority) error { + return s.caRepo.Update(*ca) +} + +// DeleteCA deletes a certificate authority +func (s *CertificateAuthorityService) DeleteCA(caID string) error { + ca, err := s.caRepo.GetByID(caID) + if err != nil { + return fmt.Errorf("CA not found: %w", err) + } + + // Check if CA has children + children, err := s.GetCAByParent(caID) + if err != nil { + return fmt.Errorf("failed to check CA children: %w", err) + } + + if len(children) > 0 { + return fmt.Errorf("cannot delete CA with children: %d intermediate CAs depend on this CA", len(children)) + } + + // Delete from database + if err := s.caRepo.Delete(caID); err != nil { + return fmt.Errorf("failed to delete CA from database: %w", err) + } + + // Clean up files + s.cleanupFiles(ca.FileID, ca.PrivateKeyID) + + return nil +} + +// RevokeCA revokes a certificate authority (marks as revoked but keeps in database) +func (s *CertificateAuthorityService) RevokeCA(caID string, reason string) error { + ca, err := s.caRepo.GetByID(caID) + if err != nil { + return fmt.Errorf("CA not found: %w", err) + } + + // Update CA with revocation information + ca.Description = fmt.Sprintf("%s [REVOKED: %s]", ca.Description, reason) + + if err := s.caRepo.Update(ca); err != nil { + return fmt.Errorf("failed to revoke CA: %w", err) + } + + return nil +} + +// ValidateCAChain validates the certificate chain for a CA +func (s *CertificateAuthorityService) ValidateCAChain(caID string) error { + ca, err := s.caRepo.GetByID(caID) + if err != nil { + return fmt.Errorf("CA not found: %w", err) + } + + // Load CA certificate + cert, _, err := s.loadCACertificateAndKey(&ca) + if err != nil { + return fmt.Errorf("failed to load CA certificate: %w", err) + } + + // If it's a root CA, validate it's self-signed + if ca.Root { + if err := cert.CheckSignatureFrom(cert); err != nil { + return fmt.Errorf("root CA certificate is not properly self-signed: %w", err) + } + return nil + } + + // For intermediate CAs, validate the chain + if ca.ParentID == nil { + return fmt.Errorf("intermediate CA must have a parent") + } + + parentCA, err := s.caRepo.GetByID(ca.ParentID.String()) + if err != nil { + return fmt.Errorf("parent CA not found: %w", err) + } + + parentCert, _, err := s.loadCACertificateAndKey(&parentCA) + if err != nil { + return fmt.Errorf("failed to load parent CA certificate: %w", err) + } + + // Validate certificate chain + if err := cert.CheckSignatureFrom(parentCert); err != nil { + return fmt.Errorf("CA certificate is not properly signed by parent: %w", err) + } + + // Check validity period + now := time.Now() + if now.Before(cert.NotBefore) { + return fmt.Errorf("CA certificate is not yet valid") + } + if now.After(cert.NotAfter) { + return fmt.Errorf("CA certificate has expired") + } + + return nil +} + +// Helper methods + +// saveCertificate saves a certificate to a file and returns the file ID +func (s *CertificateAuthorityService) saveCertificate(cert *x509.Certificate, prefix string) (string, error) { + // Create certificate directory if it doesn't exist + if err := os.MkdirAll(s.certDir, 0755); err != nil { + return "", fmt.Errorf("failed to create certificate directory: %w", err) + } + + // Generate file ID + fileID := fmt.Sprintf("%s-%s", prefix, uuid.New().String()) + + // Export certificate to PEM + exporter := utils.NewCertificateExporter() + certPEM, err := exporter.ExportCertificateToPEM(cert) + if err != nil { + return "", fmt.Errorf("failed to export certificate to PEM: %w", err) + } + + // Save to file + certPath := filepath.Join(s.certDir, fmt.Sprintf("%s.crt", fileID)) + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { + return "", fmt.Errorf("failed to write certificate file: %w", err) + } + + return fileID, nil +} + +// savePrivateKey saves a private key to a file and returns the file ID +func (s *CertificateAuthorityService) savePrivateKey(privateKey interface{}, prefix string) (string, error) { + // Create private key directory if it doesn't exist + if err := os.MkdirAll(s.privateDir, 0700); err != nil { + return "", fmt.Errorf("failed to create private key directory: %w", err) + } + + // Generate file ID + fileID := fmt.Sprintf("%s-%s", prefix, uuid.New().String()) + + // Export private key to PEM + exporter := utils.NewCertificateExporter() + keyPEM, err := exporter.ExportPrivateKeyToPEM(privateKey) + if err != nil { + return "", fmt.Errorf("failed to export private key to PEM: %w", err) + } + + // Save to file + keyPath := filepath.Join(s.privateDir, fmt.Sprintf("%s.key", fileID)) + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return "", fmt.Errorf("failed to write private key file: %w", err) + } + + return fileID, nil +} + +// loadCACertificateAndKey loads a CA certificate and private key from files +func (s *CertificateAuthorityService) loadCACertificateAndKey(ca *models.CertificateAuthority) (*x509.Certificate, interface{}, error) { + // Load certificate + certPath := filepath.Join(s.certDir, fmt.Sprintf("%s.crt", ca.FileID)) + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to read certificate file: %w", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, nil, fmt.Errorf("failed to decode certificate PEM") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + // Load private key + keyPath := filepath.Join(s.privateDir, fmt.Sprintf("%s.key", ca.PrivateKeyID)) + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to read private key file: %w", err) + } + + block, _ = pem.Decode(keyPEM) + if block == nil { + return nil, nil, fmt.Errorf("failed to decode private key PEM") + } + + var privateKey interface{} + switch block.Type { + case "RSA PRIVATE KEY": + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + privateKey, err = x509.ParseECPrivateKey(block.Bytes) + default: + return nil, nil, fmt.Errorf("unsupported private key type: %s", block.Type) + } + + if err != nil { + return nil, nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return cert, privateKey, nil +} + +// cleanupFiles removes certificate and private key files +func (s *CertificateAuthorityService) cleanupFiles(certFileID, privateKeyFileID string) { + if certFileID != "" { + certPath := filepath.Join(s.certDir, fmt.Sprintf("%s.crt", certFileID)) + os.Remove(certPath) + } + + if privateKeyFileID != "" { + keyPath := filepath.Join(s.privateDir, fmt.Sprintf("%s.key", privateKeyFileID)) + os.Remove(keyPath) + } +} + +// validateRootCARequest validates a root CA creation request +func (s *CertificateAuthorityService) validateRootCARequest(req *CreateRootCARequest) error { + if req.Name == "" { + return fmt.Errorf("name is required") + } + if req.CommonName == "" { + return fmt.Errorf("common name is required") + } + if req.Organization == "" { + return fmt.Errorf("organization is required") + } + if req.Country == "" { + return fmt.Errorf("country is required") + } + if req.OrganizationID == uuid.Nil { + return fmt.Errorf("organization ID is required") + } + return nil +} + +// validateIntermediateCARequest validates an intermediate CA creation request +func (s *CertificateAuthorityService) validateIntermediateCARequest(req *CreateIntermediateCARequest) error { + if req.Name == "" { + return fmt.Errorf("name is required") + } + if req.CommonName == "" { + return fmt.Errorf("common name is required") + } + if req.Organization == "" { + return fmt.Errorf("organization is required") + } + if req.Country == "" { + return fmt.Errorf("country is required") + } + if req.OrganizationID == uuid.Nil { + return fmt.Errorf("organization ID is required") + } + if req.ParentCAID == uuid.Nil { + return fmt.Errorf("parent CA ID is required") + } + return nil +} + +// Request structures + +// CreateRootCARequest represents a request to create a root CA +type CreateRootCARequest struct { + Name string + Description string + CommonName string + Organization string + OrganizationalUnit string + Country string + State string + Locality string + Street string + Address string + PostalCode string + Email string + OrganizationID uuid.UUID + ValidityYears int // Custom validity period in years (0 = use default) +} + +// CreateIntermediateCARequest represents a request to create an intermediate CA +type CreateIntermediateCARequest struct { + Name string + Description string + CommonName string + Organization string + OrganizationalUnit string + Country string + State string + Locality string + Street string + Address string + PostalCode string + Email string + OrganizationID uuid.UUID + ParentCAID uuid.UUID + ValidityYears int // Custom validity period in years (0 = use default) +} diff --git a/certificate/authority_test.go b/certificate/authority_test.go new file mode 100644 index 0000000..d745743 --- /dev/null +++ b/certificate/authority_test.go @@ -0,0 +1,690 @@ +package certificate + +import ( + "os" + "path/filepath" + "testing" + + "git.secnex.io/secnex/certman/models" + "github.com/google/uuid" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func setupTestDatabase(t *testing.T) (*gorm.DB, func()) { + // Create temporary database file + dbPath := filepath.Join(t.TempDir(), "test_certman.db") + + db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) + if err != nil { + t.Fatalf("Failed to connect to database: %v", err) + } + + // Auto-migrate models + if err := db.AutoMigrate( + &models.Organization{}, + &models.CertificateAuthority{}, + ); err != nil { + t.Fatalf("Failed to migrate database: %v", err) + } + + cleanup := func() { + os.Remove(dbPath) + } + + return db, cleanup +} + +func setupTestService(t *testing.T) (*CertificateAuthorityService, *gorm.DB, func()) { + db, dbCleanup := setupTestDatabase(t) + + certDir := filepath.Join(t.TempDir(), "certs") + privateDir := filepath.Join(t.TempDir(), "private") + + service := NewCertificateAuthorityService(db, certDir, privateDir) + + cleanup := func() { + dbCleanup() + os.RemoveAll(certDir) + os.RemoveAll(privateDir) + } + + return service, db, cleanup +} + +func createTestOrganization(t *testing.T, db *gorm.DB) *models.Organization { + org := &models.Organization{ + Name: "Test Organization", + Description: "Test Organization for CA tests", + Status: models.OrganizationStatusActive, + } + + if err := db.Create(org).Error; err != nil { + t.Fatalf("Failed to create test organization: %v", err) + } + + return org +} + +func TestCreateRootCA(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + req := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + ca, err := service.CreateRootCA(req) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + if ca == nil { + t.Fatal("Created CA should not be nil") + } + + if !ca.Root { + t.Error("Created CA should be marked as root") + } + + if ca.ParentID != nil { + t.Error("Root CA should not have a parent ID") + } + + if ca.Name != req.Name { + t.Errorf("Expected name %s, got %s", req.Name, ca.Name) + } + + if ca.AttributeCommonName != req.CommonName { + t.Errorf("Expected common name %s, got %s", req.CommonName, ca.AttributeCommonName) + } + + if ca.FileID == "" { + t.Error("Certificate file ID should not be empty") + } + + if ca.PrivateKeyID == "" { + t.Error("Private key file ID should not be empty") + } +} + +func TestCreateIntermediateCA(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // First create a root CA + rootReq := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + rootCA, err := service.CreateRootCA(rootReq) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Now create an intermediate CA + req := &CreateIntermediateCARequest{ + Name: "Test Intermediate CA", + Description: "Test Intermediate CA for testing", + CommonName: "Test Intermediate CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "intermediate@example.com", + OrganizationID: org.ID, + ParentCAID: rootCA.ID, + } + + ca, err := service.CreateIntermediateCA(req) + if err != nil { + t.Fatalf("Failed to create intermediate CA: %v", err) + } + + if ca == nil { + t.Fatal("Created CA should not be nil") + } + + if ca.Root { + t.Error("Created CA should not be marked as root") + } + + if ca.ParentID == nil { + t.Error("Intermediate CA should have a parent ID") + } + + if ca.ParentID.String() != rootCA.ID.String() { + t.Errorf("Expected parent ID %s, got %s", rootCA.ID.String(), ca.ParentID.String()) + } + + if ca.Name != req.Name { + t.Errorf("Expected name %s, got %s", req.Name, ca.Name) + } + + if ca.AttributeCommonName != req.CommonName { + t.Errorf("Expected common name %s, got %s", req.CommonName, ca.AttributeCommonName) + } + + if ca.FileID == "" { + t.Error("Certificate file ID should not be empty") + } + + if ca.PrivateKeyID == "" { + t.Error("Private key file ID should not be empty") + } +} + +func TestGetCA(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create a root CA + req := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + createdCA, err := service.CreateRootCA(req) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Retrieve the CA + retrievedCA, err := service.GetCA(createdCA.ID.String()) + if err != nil { + t.Fatalf("Failed to get CA: %v", err) + } + + if retrievedCA.ID != createdCA.ID { + t.Errorf("Expected ID %s, got %s", createdCA.ID.String(), retrievedCA.ID.String()) + } + + if retrievedCA.Name != createdCA.Name { + t.Errorf("Expected name %s, got %s", createdCA.Name, retrievedCA.Name) + } +} + +func TestGetAllCAs(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create multiple CAs + rootReq := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + rootCA, err := service.CreateRootCA(rootReq) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + intermediateReq := &CreateIntermediateCARequest{ + Name: "Test Intermediate CA", + Description: "Test Intermediate CA for testing", + CommonName: "Test Intermediate CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "intermediate@example.com", + OrganizationID: org.ID, + ParentCAID: rootCA.ID, + } + + _, err = service.CreateIntermediateCA(intermediateReq) + if err != nil { + t.Fatalf("Failed to create intermediate CA: %v", err) + } + + // Get all CAs + allCAs, err := service.GetAllCAs() + if err != nil { + t.Fatalf("Failed to get all CAs: %v", err) + } + + if len(allCAs) != 2 { + t.Errorf("Expected 2 CAs, got %d", len(allCAs)) + } + + // Get root CAs + rootCAs, err := service.GetRootCAs() + if err != nil { + t.Fatalf("Failed to get root CAs: %v", err) + } + + if len(rootCAs) != 1 { + t.Errorf("Expected 1 root CA, got %d", len(rootCAs)) + } + + // Get intermediate CAs + intermediateCAs, err := service.GetIntermediateCAs() + if err != nil { + t.Fatalf("Failed to get intermediate CAs: %v", err) + } + + if len(intermediateCAs) != 1 { + t.Errorf("Expected 1 intermediate CA, got %d", len(intermediateCAs)) + } +} + +func TestGetCACertificate(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create a root CA + req := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + ca, err := service.CreateRootCA(req) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Get CA certificate + cert, err := service.GetCACertificate(ca.ID.String()) + if err != nil { + t.Fatalf("Failed to get CA certificate: %v", err) + } + + if cert == nil { + t.Fatal("Certificate should not be nil") + } + + if cert.Subject.CommonName != req.CommonName { + t.Errorf("Expected common name %s, got %s", req.CommonName, cert.Subject.CommonName) + } + + if !cert.IsCA { + t.Error("Certificate should be marked as CA") + } +} + +func TestGetCAPrivateKey(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create a root CA + req := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + ca, err := service.CreateRootCA(req) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Get CA private key + privateKey, err := service.GetCAPrivateKey(ca.ID.String()) + if err != nil { + t.Fatalf("Failed to get CA private key: %v", err) + } + + if privateKey == nil { + t.Fatal("Private key should not be nil") + } +} + +func TestValidateCAChain(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create a root CA + rootReq := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + rootCA, err := service.CreateRootCA(rootReq) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Validate root CA chain + if err := service.ValidateCAChain(rootCA.ID.String()); err != nil { + t.Fatalf("Failed to validate root CA chain: %v", err) + } + + // Create an intermediate CA + intermediateReq := &CreateIntermediateCARequest{ + Name: "Test Intermediate CA", + Description: "Test Intermediate CA for testing", + CommonName: "Test Intermediate CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "intermediate@example.com", + OrganizationID: org.ID, + ParentCAID: rootCA.ID, + } + + intermediateCA, err := service.CreateIntermediateCA(intermediateReq) + if err != nil { + t.Fatalf("Failed to create intermediate CA: %v", err) + } + + // Validate intermediate CA chain + if err := service.ValidateCAChain(intermediateCA.ID.String()); err != nil { + t.Fatalf("Failed to validate intermediate CA chain: %v", err) + } +} + +func TestRevokeCA(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create a root CA + req := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + ca, err := service.CreateRootCA(req) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Revoke CA + reason := "Test revocation" + if err := service.RevokeCA(ca.ID.String(), reason); err != nil { + t.Fatalf("Failed to revoke CA: %v", err) + } + + // Verify CA was updated + updatedCA, err := service.GetCA(ca.ID.String()) + if err != nil { + t.Fatalf("Failed to get updated CA: %v", err) + } + + if updatedCA.Description == ca.Description { + t.Error("CA description should have been updated with revocation info") + } +} + +func TestDeleteCA(t *testing.T) { + service, db, cleanup := setupTestService(t) + defer cleanup() + + org := createTestOrganization(t, db) + + // Create a root CA + req := &CreateRootCARequest{ + Name: "Test Root CA", + Description: "Test Root CA for testing", + CommonName: "Test Root CA", + Organization: "Test Organization", + OrganizationalUnit: "IT Department", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street 1", + Address: "Test Street 1, 80331 Munich", + PostalCode: "80331", + Email: "test@example.com", + OrganizationID: org.ID, + } + + ca, err := service.CreateRootCA(req) + if err != nil { + t.Fatalf("Failed to create root CA: %v", err) + } + + // Delete CA + if err := service.DeleteCA(ca.ID.String()); err != nil { + t.Fatalf("Failed to delete CA: %v", err) + } + + // Verify CA was deleted + _, err = service.GetCA(ca.ID.String()) + if err == nil { + t.Error("CA should have been deleted") + } +} + +func TestValidateRootCARequest(t *testing.T) { + service, _, cleanup := setupTestService(t) + defer cleanup() + + tests := []struct { + name string + req *CreateRootCARequest + wantErr bool + }{ + { + name: "valid request", + req: &CreateRootCARequest{ + Name: "Test CA", + CommonName: "Test CA", + Organization: "Test Org", + Country: "DE", + OrganizationID: uuid.New(), + }, + wantErr: false, + }, + { + name: "missing name", + req: &CreateRootCARequest{ + CommonName: "Test CA", + Organization: "Test Org", + Country: "DE", + OrganizationID: uuid.New(), + }, + wantErr: true, + }, + { + name: "missing common name", + req: &CreateRootCARequest{ + Name: "Test CA", + Organization: "Test Org", + Country: "DE", + OrganizationID: uuid.New(), + }, + wantErr: true, + }, + { + name: "missing organization", + req: &CreateRootCARequest{ + Name: "Test CA", + CommonName: "Test CA", + Country: "DE", + OrganizationID: uuid.New(), + }, + wantErr: true, + }, + { + name: "missing country", + req: &CreateRootCARequest{ + Name: "Test CA", + CommonName: "Test CA", + Organization: "Test Org", + OrganizationID: uuid.New(), + }, + wantErr: true, + }, + { + name: "missing organization ID", + req: &CreateRootCARequest{ + Name: "Test CA", + CommonName: "Test CA", + Organization: "Test Org", + Country: "DE", + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := service.validateRootCARequest(test.req) + if (err != nil) != test.wantErr { + t.Errorf("validateRootCARequest() error = %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func TestValidateIntermediateCARequest(t *testing.T) { + service, _, cleanup := setupTestService(t) + defer cleanup() + + tests := []struct { + name string + req *CreateIntermediateCARequest + wantErr bool + }{ + { + name: "valid request", + req: &CreateIntermediateCARequest{ + Name: "Test CA", + CommonName: "Test CA", + Organization: "Test Org", + Country: "DE", + OrganizationID: uuid.New(), + ParentCAID: uuid.New(), + }, + wantErr: false, + }, + { + name: "missing parent CA ID", + req: &CreateIntermediateCARequest{ + Name: "Test CA", + CommonName: "Test CA", + Organization: "Test Org", + Country: "DE", + OrganizationID: uuid.New(), + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := service.validateIntermediateCARequest(test.req) + if (err != nil) != test.wantErr { + t.Errorf("validateIntermediateCARequest() error = %v, wantErr %v", err, test.wantErr) + } + }) + } +}