From d9726d9204528f1b098eff29d4f4fc8f41d2adb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Benouarets?= Date: Tue, 30 Sep 2025 11:45:05 +0200 Subject: [PATCH] feat: add certificate service and utility functions - Implement CertificateService for end-to-end certificate management - Add support for various certificate types (web, client, email, etc.) - Include certificate generation, validation, and revocation - Add utility functions for certificate operations and validation - Implement comprehensive test coverage for certificate operations - Support SAN (Subject Alternative Name) and IP address extensions - Add proper error handling and validation for certificate requests --- certificate/certificate.go | 835 ++++++++++++++++++++++++++ certificate/certificate_test.go | 403 +++++++++++++ certificate/utils/certificate.go | 783 ++++++++++++++++++++++++ certificate/utils/certificate_test.go | 537 +++++++++++++++++ 4 files changed, 2558 insertions(+) create mode 100644 certificate/certificate.go create mode 100644 certificate/certificate_test.go create mode 100644 certificate/utils/certificate.go create mode 100644 certificate/utils/certificate_test.go diff --git a/certificate/certificate.go b/certificate/certificate.go new file mode 100644 index 0000000..b20dc04 --- /dev/null +++ b/certificate/certificate.go @@ -0,0 +1,835 @@ +package certificate + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "net" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "git.secnex.io/secnex/certman/certificate/utils" + "git.secnex.io/secnex/certman/models" + "git.secnex.io/secnex/certman/repositories" + "github.com/google/uuid" + "gorm.io/gorm" +) + +// CertificateService handles certificate operations +type CertificateService struct { + certRepo *repositories.CertificateRepository + caRepo *repositories.CertificateAuthorityRepository + csrRepo *repositories.CertificateRequestRepository + certDir string + privateDir string + caService *CertificateAuthorityService +} + +// NewCertificateService creates a new certificate service +func NewCertificateService( + db *gorm.DB, + certDir string, + privateDir string, + caService *CertificateAuthorityService, +) *CertificateService { + return &CertificateService{ + certRepo: repositories.NewCertificateRepository(db), + caRepo: repositories.NewCertificateAuthorityRepository(db), + csrRepo: repositories.NewCertificateRequestRepository(db), + certDir: certDir, + privateDir: privateDir, + caService: caService, + } +} + +// CreateCertificate creates a new certificate from a request +func (s *CertificateService) CreateCertificate(req *CreateCertificateRequest) (*models.Certificate, error) { + // Validate request + if err := s.validateCreateCertificateRequest(req); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + // Get CA + ca, err := s.caRepo.GetByID(req.CertificateAuthorityID.String()) + if err != nil { + return nil, fmt.Errorf("CA not found: %w", err) + } + + // Load CA certificate and private key + caCert, caPrivateKey, err := s.caService.loadCACertificateAndKey(&ca) + if err != nil { + return nil, fmt.Errorf("failed to load CA certificate: %w", err) + } + + // Create certificate configuration (iOS-compatible by default for web certificates) + config := s.createCertificateConfig(req) + + // Ensure proper certificate configuration + config.SubjectKeyID = true + config.AuthorityKeyID = true + + // Validate configuration + if err := utils.ValidateCertificateConfig(config); err != nil { + return nil, fmt.Errorf("invalid certificate configuration: %w", err) + } + + // Generate certificate and private key + generator := utils.NewCertificateGenerator(config) + cert, privateKey, err := generator.GenerateCertificate(caCert, caPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to generate certificate: %w", err) + } + + // Save certificate and private key to files + certFileID, err := s.saveCertificate(cert, req.Type) + if err != nil { + return nil, fmt.Errorf("failed to save certificate: %w", err) + } + + privateKeyFileID, err := s.savePrivateKey(privateKey, req.Type) + if err != nil { + return nil, fmt.Errorf("failed to save private key: %w", err) + } + + // Create certificate model + certModel := &models.Certificate{ + Name: req.Name, + Description: req.Description, + SerialNumber: cert.SerialNumber.String(), + AttributeCommonName: cert.Subject.CommonName, + AttributeSubjectAlternativeName: s.formatSANs(cert.DNSNames, cert.IPAddresses, cert.EmailAddresses, cert.URIs), + AttributeOrganization: cert.Subject.Organization[0], + AttributeOrganizationUnit: cert.Subject.OrganizationalUnit[0], + AttributeCountry: cert.Subject.Country[0], + AttributeState: cert.Subject.Province[0], + AttributeLocality: cert.Subject.Locality[0], + AttributeStreet: cert.Subject.StreetAddress[0], + AttributeEmail: req.Email, + AttributeAddress: req.Address, + AttributePostalCode: cert.Subject.PostalCode[0], + AttributeNotBefore: cert.NotBefore, + AttributeNotAfter: cert.NotAfter, + Type: req.Type, + Status: models.CertificateStatusActive, + CertificateAuthorityID: req.CertificateAuthorityID, + RequestID: req.RequestID, + FileID: certFileID, + PrivateKeyID: privateKeyFileID, + Generated: true, + GeneratedAt: &time.Time{}, + } + + // Set generated time + now := time.Now() + certModel.GeneratedAt = &now + + // Save to database + if err := s.certRepo.Create(*certModel); err != nil { + // Clean up files if database save fails + s.cleanupFiles(certFileID, privateKeyFileID) + return nil, fmt.Errorf("failed to save certificate to database: %w", err) + } + + return certModel, nil +} + +// CreateCertificateFromCSR creates a certificate from a Certificate Signing Request +func (s *CertificateService) CreateCertificateFromCSR(req *CreateCertificateFromCSRRequest) (*models.Certificate, error) { + // Validate request + if err := s.validateCreateCertificateFromCSRRequest(req); err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + // Get CSR + csr, err := s.csrRepo.GetByID(req.CSRID.String()) + if err != nil { + return nil, fmt.Errorf("CSR not found: %w", err) + } + + // Get CA + ca, err := s.caRepo.GetByID(req.CertificateAuthorityID.String()) + if err != nil { + return nil, fmt.Errorf("CA not found: %w", err) + } + + // Load CA certificate and private key + caCert, caPrivateKey, err := s.caService.loadCACertificateAndKey(&ca) + if err != nil { + return nil, fmt.Errorf("failed to load CA certificate: %w", err) + } + + // Parse CSR + csrParsed, err := x509.ParseCertificateRequest(csr.CSRData) + if err != nil { + return nil, fmt.Errorf("failed to parse CSR: %w", err) + } + + // Create certificate configuration from CSR + config := s.createCertificateConfigFromCSR(req, csrParsed) + + // Validate configuration + if err := utils.ValidateCertificateConfig(config); err != nil { + return nil, fmt.Errorf("invalid certificate configuration: %w", err) + } + + // Generate certificate + generator := utils.NewCertificateGenerator(config) + cert, privateKey, err := generator.GenerateCertificate(caCert, caPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to generate certificate: %w", err) + } + + // Save certificate and private key to files + certFileID, err := s.saveCertificate(cert, req.Type) + if err != nil { + return nil, fmt.Errorf("failed to save certificate: %w", err) + } + + privateKeyFileID, err := s.savePrivateKey(privateKey, req.Type) + if err != nil { + return nil, fmt.Errorf("failed to save private key: %w", err) + } + + // Create certificate model + certModel := &models.Certificate{ + Name: req.Name, + Description: req.Description, + SerialNumber: cert.SerialNumber.String(), + AttributeCommonName: cert.Subject.CommonName, + AttributeSubjectAlternativeName: s.formatSANs(cert.DNSNames, cert.IPAddresses, cert.EmailAddresses, cert.URIs), + AttributeOrganization: cert.Subject.Organization[0], + AttributeOrganizationUnit: cert.Subject.OrganizationalUnit[0], + AttributeCountry: cert.Subject.Country[0], + AttributeState: cert.Subject.Province[0], + AttributeLocality: cert.Subject.Locality[0], + AttributeStreet: cert.Subject.StreetAddress[0], + AttributeEmail: "", + AttributeAddress: "", + AttributePostalCode: cert.Subject.PostalCode[0], + AttributeNotBefore: cert.NotBefore, + AttributeNotAfter: cert.NotAfter, + Type: req.Type, + Status: models.CertificateStatusActive, + CertificateAuthorityID: req.CertificateAuthorityID, + RequestID: &req.CSRID, + FileID: certFileID, + PrivateKeyID: privateKeyFileID, + Generated: true, + GeneratedAt: &time.Time{}, + } + + // Set generated time + now := time.Now() + certModel.GeneratedAt = &now + + // Save to database + if err := s.certRepo.Create(*certModel); err != nil { + // Clean up files if database save fails + s.cleanupFiles(certFileID, privateKeyFileID) + return nil, fmt.Errorf("failed to save certificate to database: %w", err) + } + + return certModel, nil +} + +// GetCertificate retrieves a certificate by ID +func (s *CertificateService) GetCertificate(id string) (*models.Certificate, error) { + cert, err := s.certRepo.GetByID(id) + if err != nil { + return nil, fmt.Errorf("certificate not found: %w", err) + } + return &cert, nil +} + +// GetAllCertificates retrieves all certificates +func (s *CertificateService) GetAllCertificates() ([]models.Certificate, error) { + certs, err := s.certRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve certificates: %w", err) + } + return certs, nil +} + +// GetCertificatesByType retrieves certificates by type +func (s *CertificateService) GetCertificatesByType(certType models.CertificateType) ([]models.Certificate, error) { + certs, err := s.certRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve certificates: %w", err) + } + + var filteredCerts []models.Certificate + for _, cert := range certs { + if cert.Type == certType { + filteredCerts = append(filteredCerts, cert) + } + } + + return filteredCerts, nil +} + +// GetCertificatesByCA retrieves certificates issued by a specific CA +func (s *CertificateService) GetCertificatesByCA(caID string) ([]models.Certificate, error) { + certs, err := s.certRepo.GetAll() + if err != nil { + return nil, fmt.Errorf("failed to retrieve certificates: %w", err) + } + + var filteredCerts []models.Certificate + for _, cert := range certs { + if cert.CertificateAuthorityID.String() == caID { + filteredCerts = append(filteredCerts, cert) + } + } + + return filteredCerts, nil +} + +// GetCertificateFile retrieves the certificate file +func (s *CertificateService) GetCertificateFile(certID string) (*x509.Certificate, error) { + cert, err := s.certRepo.GetByID(certID) + if err != nil { + return nil, fmt.Errorf("certificate not found: %w", err) + } + + certPath := filepath.Join(s.certDir, fmt.Sprintf("%s.crt", cert.FileID)) + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("failed to read certificate file: %w", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, fmt.Errorf("failed to decode certificate PEM") + } + + certParsed, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return certParsed, nil +} + +// GetPrivateKeyFile retrieves the private key file +func (s *CertificateService) GetPrivateKeyFile(certID string) (interface{}, error) { + cert, err := s.certRepo.GetByID(certID) + if err != nil { + return nil, fmt.Errorf("certificate not found: %w", err) + } + + keyPath := filepath.Join(s.privateDir, fmt.Sprintf("%s.key", cert.PrivateKeyID)) + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read private key file: %w", err) + } + + block, _ := pem.Decode(keyPEM) + if block == nil { + return nil, fmt.Errorf("failed to decode private key PEM") + } + + var privateKey interface{} + switch block.Type { + case "RSA PRIVATE KEY": + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + privateKey, err = x509.ParseECPrivateKey(block.Bytes) + default: + return nil, fmt.Errorf("unsupported private key type: %s", block.Type) + } + + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return privateKey, nil +} + +// UpdateCertificate updates a certificate +func (s *CertificateService) UpdateCertificate(cert *models.Certificate) error { + return s.certRepo.Update(*cert) +} + +// RevokeCertificate revokes a certificate +func (s *CertificateService) RevokeCertificate(certID string, reason string) error { + cert, err := s.certRepo.GetByID(certID) + if err != nil { + return fmt.Errorf("certificate not found: %w", err) + } + + // Update certificate with revocation information + cert.Status = models.CertificateStatusRevoked + cert.RevocationReason = reason + now := time.Now() + cert.RevokedAt = &now + + if err := s.certRepo.Update(cert); err != nil { + return fmt.Errorf("failed to revoke certificate: %w", err) + } + + return nil +} + +// DeleteCertificate deletes a certificate +func (s *CertificateService) DeleteCertificate(certID string) error { + cert, err := s.certRepo.GetByID(certID) + if err != nil { + return fmt.Errorf("certificate not found: %w", err) + } + + // Delete from database + if err := s.certRepo.Delete(certID); err != nil { + return fmt.Errorf("failed to delete certificate from database: %w", err) + } + + // Clean up files + s.cleanupFiles(cert.FileID, cert.PrivateKeyID) + + return nil +} + +// ValidateCertificate validates a certificate +func (s *CertificateService) ValidateCertificate(certID string) error { + cert, err := s.certRepo.GetByID(certID) + if err != nil { + return fmt.Errorf("certificate not found: %w", err) + } + + // Load certificate + certParsed, err := s.GetCertificateFile(certID) + if err != nil { + return fmt.Errorf("failed to load certificate: %w", err) + } + + // Check validity period + now := time.Now() + if now.Before(certParsed.NotBefore) { + return fmt.Errorf("certificate is not yet valid") + } + if now.After(certParsed.NotAfter) { + return fmt.Errorf("certificate has expired") + } + + // Check status + if cert.Status == models.CertificateStatusRevoked { + return fmt.Errorf("certificate has been revoked: %s", cert.RevocationReason) + } + + return nil +} + +// ExportCertificate exports a certificate in various formats +func (s *CertificateService) ExportCertificate(certID string, format string) ([]byte, error) { + certParsed, err := s.GetCertificateFile(certID) + if err != nil { + return nil, fmt.Errorf("failed to load certificate: %w", err) + } + + exporter := utils.NewCertificateExporter() + + switch format { + case "pem": + return exporter.ExportCertificateToPEM(certParsed) + case "der": + return exporter.ExportCertificateToDER(certParsed), nil + default: + return nil, fmt.Errorf("unsupported export format: %s", format) + } +} + +// ExportPrivateKey exports a private key in various formats +func (s *CertificateService) ExportPrivateKey(certID string, format string) ([]byte, error) { + privateKey, err := s.GetPrivateKeyFile(certID) + if err != nil { + return nil, fmt.Errorf("failed to load private key: %w", err) + } + + exporter := utils.NewCertificateExporter() + + switch format { + case "pem": + return exporter.ExportPrivateKeyToPEM(privateKey) + case "der": + return exporter.ExportPrivateKeyToDER(privateKey) + default: + return nil, fmt.Errorf("unsupported export format: %s", format) + } +} + +// Helper methods + +// createCertificateConfig creates a certificate configuration from a request +func (s *CertificateService) createCertificateConfig(req *CreateCertificateRequest) *utils.CertificateConfig { + config := utils.DefaultCertificateConfig(req.Type) + + // Set basic information + config.CommonName = req.CommonName + config.Organization = req.Organization + config.OrganizationalUnit = req.OrganizationalUnit + config.Country = req.Country + config.State = req.State + config.Locality = req.Locality + config.Street = req.Street + config.PostalCode = req.PostalCode + config.Email = req.Email + + // Set validity period + if req.NotBefore != nil { + config.NotBefore = *req.NotBefore + } + if req.NotAfter != nil { + config.NotAfter = *req.NotAfter + } + if req.ValidityYears > 0 { + config.ValidityYears = req.ValidityYears + config.NotAfter = config.NotBefore.AddDate(req.ValidityYears, 0, 0) + } + + // Set SANs + config.DNSNames = req.DNSNames + config.IPAddresses = req.IPAddresses + config.EmailAddresses = req.EmailAddresses + config.URIs = req.URIs + + // Set key configuration + if req.KeyType != "" { + config.KeyType = utils.KeyType(req.KeyType) + } + if req.KeySize != 0 { + config.KeySize = utils.KeySize(req.KeySize) + } + if req.Curve != "" { + config.Curve = utils.Curve(req.Curve) + } + + // Set custom key usage and extended key usage + if req.KeyUsage != nil { + config.DigitalSignature = req.KeyUsage.DigitalSignature + config.ContentCommitment = req.KeyUsage.ContentCommitment + config.KeyEncipherment = req.KeyUsage.KeyEncipherment + config.DataEncipherment = req.KeyUsage.DataEncipherment + config.KeyAgreement = req.KeyUsage.KeyAgreement + config.KeyCertSign = req.KeyUsage.KeyCertSign + config.CRLSign = req.KeyUsage.CRLSign + config.EncipherOnly = req.KeyUsage.EncipherOnly + config.DecipherOnly = req.KeyUsage.DecipherOnly + } + + if req.ExtendedKeyUsage != nil { + config.ServerAuth = req.ExtendedKeyUsage.ServerAuth + config.ClientAuth = req.ExtendedKeyUsage.ClientAuth + config.CodeSigning = req.ExtendedKeyUsage.CodeSigning + config.EmailProtection = req.ExtendedKeyUsage.EmailProtection + config.TimeStamping = req.ExtendedKeyUsage.TimeStamping + config.OCSPSigning = req.ExtendedKeyUsage.OCSPSigning + } + + return config +} + +// createCertificateConfigFromCSR creates a certificate configuration from a CSR +func (s *CertificateService) createCertificateConfigFromCSR(req *CreateCertificateFromCSRRequest, csr *x509.CertificateRequest) *utils.CertificateConfig { + config := utils.DefaultCertificateConfig(req.Type) + + // Set basic information from CSR + config.CommonName = csr.Subject.CommonName + if len(csr.Subject.Organization) > 0 { + config.Organization = csr.Subject.Organization[0] + } + if len(csr.Subject.OrganizationalUnit) > 0 { + config.OrganizationalUnit = csr.Subject.OrganizationalUnit[0] + } + if len(csr.Subject.Country) > 0 { + config.Country = csr.Subject.Country[0] + } + if len(csr.Subject.Province) > 0 { + config.State = csr.Subject.Province[0] + } + if len(csr.Subject.Locality) > 0 { + config.Locality = csr.Subject.Locality[0] + } + if len(csr.Subject.StreetAddress) > 0 { + config.Street = csr.Subject.StreetAddress[0] + } + if len(csr.Subject.PostalCode) > 0 { + config.PostalCode = csr.Subject.PostalCode[0] + } + + // Set validity period + if req.NotBefore != nil { + config.NotBefore = *req.NotBefore + } + if req.NotAfter != nil { + config.NotAfter = *req.NotAfter + } + if req.ValidityYears > 0 { + config.ValidityYears = req.ValidityYears + config.NotAfter = config.NotBefore.AddDate(req.ValidityYears, 0, 0) + } + + // Set SANs from CSR + config.DNSNames = csr.DNSNames + config.IPAddresses = csr.IPAddresses + config.EmailAddresses = csr.EmailAddresses + config.URIs = s.parseURIsFromCSR(csr.URIs) + + // Set key configuration + if req.KeyType != "" { + config.KeyType = utils.KeyType(req.KeyType) + } + if req.KeySize != 0 { + config.KeySize = utils.KeySize(req.KeySize) + } + if req.Curve != "" { + config.Curve = utils.Curve(req.Curve) + } + + // Set custom key usage and extended key usage + if req.KeyUsage != nil { + config.DigitalSignature = req.KeyUsage.DigitalSignature + config.ContentCommitment = req.KeyUsage.ContentCommitment + config.KeyEncipherment = req.KeyUsage.KeyEncipherment + config.DataEncipherment = req.KeyUsage.DataEncipherment + config.KeyAgreement = req.KeyUsage.KeyAgreement + config.KeyCertSign = req.KeyUsage.KeyCertSign + config.CRLSign = req.KeyUsage.CRLSign + config.EncipherOnly = req.KeyUsage.EncipherOnly + config.DecipherOnly = req.KeyUsage.DecipherOnly + } + + if req.ExtendedKeyUsage != nil { + config.ServerAuth = req.ExtendedKeyUsage.ServerAuth + config.ClientAuth = req.ExtendedKeyUsage.ClientAuth + config.CodeSigning = req.ExtendedKeyUsage.CodeSigning + config.EmailProtection = req.ExtendedKeyUsage.EmailProtection + config.TimeStamping = req.ExtendedKeyUsage.TimeStamping + config.OCSPSigning = req.ExtendedKeyUsage.OCSPSigning + } + + return config +} + +// formatSANs formats SANs for storage +func (s *CertificateService) formatSANs(dnsNames []string, ipAddresses []net.IP, emailAddresses []string, uris []*url.URL) string { + var sans []string + + for _, dns := range dnsNames { + sans = append(sans, dns) + } + + for _, ip := range ipAddresses { + sans = append(sans, ip.String()) + } + + for _, email := range emailAddresses { + sans = append(sans, email) + } + + for _, uri := range uris { + sans = append(sans, uri.String()) + } + + return strings.Join(sans, ",") +} + +// parseURIsFromCSR parses URIs from CSR +func (s *CertificateService) parseURIsFromCSR(uris []*url.URL) []string { + var uriStrings []string + for _, uri := range uris { + uriStrings = append(uriStrings, uri.String()) + } + return uriStrings +} + +// saveCertificate saves a certificate to a file and returns the file ID +func (s *CertificateService) saveCertificate(cert *x509.Certificate, certType models.CertificateType) (string, error) { + // Create certificate directory if it doesn't exist + if err := os.MkdirAll(s.certDir, 0755); err != nil { + return "", fmt.Errorf("failed to create certificate directory: %w", err) + } + + // Generate file ID + fileID := fmt.Sprintf("%s-%s", certType, uuid.New().String()) + + // Export certificate to PEM + exporter := utils.NewCertificateExporter() + certPEM, err := exporter.ExportCertificateToPEM(cert) + if err != nil { + return "", fmt.Errorf("failed to export certificate to PEM: %w", err) + } + + // Save to file + certPath := filepath.Join(s.certDir, fmt.Sprintf("%s.crt", fileID)) + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { + return "", fmt.Errorf("failed to write certificate file: %w", err) + } + + return fileID, nil +} + +// savePrivateKey saves a private key to a file and returns the file ID +func (s *CertificateService) savePrivateKey(privateKey interface{}, certType models.CertificateType) (string, error) { + // Create private key directory if it doesn't exist + if err := os.MkdirAll(s.privateDir, 0700); err != nil { + return "", fmt.Errorf("failed to create private key directory: %w", err) + } + + // Generate file ID + fileID := fmt.Sprintf("%s-%s", certType, uuid.New().String()) + + // Export private key to PEM + exporter := utils.NewCertificateExporter() + keyPEM, err := exporter.ExportPrivateKeyToPEM(privateKey) + if err != nil { + return "", fmt.Errorf("failed to export private key to PEM: %w", err) + } + + // Save to file + keyPath := filepath.Join(s.privateDir, fmt.Sprintf("%s.key", fileID)) + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return "", fmt.Errorf("failed to write private key file: %w", err) + } + + return fileID, nil +} + +// cleanupFiles removes certificate and private key files +func (s *CertificateService) cleanupFiles(certFileID, privateKeyFileID string) { + if certFileID != "" { + certPath := filepath.Join(s.certDir, fmt.Sprintf("%s.crt", certFileID)) + os.Remove(certPath) + } + + if privateKeyFileID != "" { + keyPath := filepath.Join(s.privateDir, fmt.Sprintf("%s.key", privateKeyFileID)) + os.Remove(keyPath) + } +} + +// validateCreateCertificateRequest validates a certificate creation request +func (s *CertificateService) validateCreateCertificateRequest(req *CreateCertificateRequest) error { + if req.Name == "" { + return fmt.Errorf("name is required") + } + if req.CommonName == "" { + return fmt.Errorf("common name is required") + } + if req.Organization == "" { + return fmt.Errorf("organization is required") + } + if req.Country == "" { + return fmt.Errorf("country is required") + } + if req.CertificateAuthorityID == uuid.Nil { + return fmt.Errorf("certificate authority ID is required") + } + return nil +} + +// validateCreateCertificateFromCSRRequest validates a certificate creation from CSR request +func (s *CertificateService) validateCreateCertificateFromCSRRequest(req *CreateCertificateFromCSRRequest) error { + if req.Name == "" { + return fmt.Errorf("name is required") + } + if req.CSRID == uuid.Nil { + return fmt.Errorf("CSR ID is required") + } + if req.CertificateAuthorityID == uuid.Nil { + return fmt.Errorf("certificate authority ID is required") + } + return nil +} + +// Request structures + +// CreateCertificateRequest represents a request to create a certificate +type CreateCertificateRequest struct { + Name string + Description string + CommonName string + Organization string + OrganizationalUnit string + Country string + State string + Locality string + Street string + Address string + PostalCode string + Email string + Type models.CertificateType + CertificateAuthorityID uuid.UUID + RequestID *uuid.UUID + + // Validity period + NotBefore *time.Time + NotAfter *time.Time + ValidityYears int // Custom validity period in years (overrides NotAfter if set) + + // Subject Alternative Names + DNSNames []string + IPAddresses []net.IP + EmailAddresses []string + URIs []string + + // Key configuration + KeyType string // "rsa" or "ecdsa" + KeySize int // RSA key size (2048, 3072, 4096) + Curve string // ECDSA curve ("P-256", "P-384", "P-521") + + // Custom key usage + KeyUsage *KeyUsageConfig + + // Custom extended key usage + ExtendedKeyUsage *ExtendedKeyUsageConfig +} + +// CreateCertificateFromCSRRequest represents a request to create a certificate from a CSR +type CreateCertificateFromCSRRequest struct { + Name string + Description string + Type models.CertificateType + CSRID uuid.UUID + CertificateAuthorityID uuid.UUID + + // Validity period + NotBefore *time.Time + NotAfter *time.Time + ValidityYears int // Custom validity period in years (overrides NotAfter if set) + + // Key configuration + KeyType string // "rsa" or "ecdsa" + KeySize int // RSA key size (2048, 3072, 4096) + Curve string // ECDSA curve ("P-256", "P-384", "P-521") + + // Custom key usage + KeyUsage *KeyUsageConfig + + // Custom extended key usage + ExtendedKeyUsage *ExtendedKeyUsageConfig +} + +// KeyUsageConfig represents key usage configuration +type KeyUsageConfig struct { + DigitalSignature bool + ContentCommitment bool + KeyEncipherment bool + DataEncipherment bool + KeyAgreement bool + KeyCertSign bool + CRLSign bool + EncipherOnly bool + DecipherOnly bool +} + +// ExtendedKeyUsageConfig represents extended key usage configuration +type ExtendedKeyUsageConfig struct { + ServerAuth bool + ClientAuth bool + CodeSigning bool + EmailProtection bool + TimeStamping bool + OCSPSigning bool +} diff --git a/certificate/certificate_test.go b/certificate/certificate_test.go new file mode 100644 index 0000000..e8d5126 --- /dev/null +++ b/certificate/certificate_test.go @@ -0,0 +1,403 @@ +package certificate + +import ( + "net" + "testing" + "time" + + "git.secnex.io/secnex/certman/models" + "github.com/google/uuid" +) + +func TestCreateCertificateRequest(t *testing.T) { + service := &CertificateService{} + + tests := []struct { + name string + req *CreateCertificateRequest + wantErr bool + }{ + { + name: "valid web certificate request", + req: &CreateCertificateRequest{ + Name: "Test Web Certificate", + CommonName: "example.com", + Organization: "Test Org", + Country: "DE", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: false, + }, + { + name: "valid client certificate request", + req: &CreateCertificateRequest{ + Name: "Test Client Certificate", + CommonName: "user@example.com", + Organization: "Test Org", + Country: "DE", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeClient, + }, + wantErr: false, + }, + { + name: "valid IoT certificate request", + req: &CreateCertificateRequest{ + Name: "Test IoT Certificate", + CommonName: "device.example.com", + Organization: "Test Org", + Country: "DE", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeIoT, + }, + wantErr: false, + }, + { + name: "missing name", + req: &CreateCertificateRequest{ + CommonName: "example.com", + Organization: "Test Org", + Country: "DE", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + { + name: "missing common name", + req: &CreateCertificateRequest{ + Name: "Test Certificate", + Organization: "Test Org", + Country: "DE", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + { + name: "missing organization", + req: &CreateCertificateRequest{ + Name: "Test Certificate", + CommonName: "example.com", + Country: "DE", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + { + name: "missing country", + req: &CreateCertificateRequest{ + Name: "Test Certificate", + CommonName: "example.com", + Organization: "Test Org", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + { + name: "missing CA ID", + req: &CreateCertificateRequest{ + Name: "Test Certificate", + CommonName: "example.com", + Organization: "Test Org", + Country: "DE", + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := service.validateCreateCertificateRequest(test.req) + if (err != nil) != test.wantErr { + t.Errorf("validateCreateCertificateRequest() error = %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func TestCreateCertificateFromCSRRequest(t *testing.T) { + service := &CertificateService{} + + tests := []struct { + name string + req *CreateCertificateFromCSRRequest + wantErr bool + }{ + { + name: "valid CSR request", + req: &CreateCertificateFromCSRRequest{ + Name: "Test Certificate from CSR", + CSRID: uuid.New(), + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: false, + }, + { + name: "missing name", + req: &CreateCertificateFromCSRRequest{ + CSRID: uuid.New(), + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + { + name: "missing CSR ID", + req: &CreateCertificateFromCSRRequest{ + Name: "Test Certificate from CSR", + CertificateAuthorityID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + { + name: "missing CA ID", + req: &CreateCertificateFromCSRRequest{ + Name: "Test Certificate from CSR", + CSRID: uuid.New(), + Type: models.CertificateTypeWeb, + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := service.validateCreateCertificateFromCSRRequest(test.req) + if (err != nil) != test.wantErr { + t.Errorf("validateCreateCertificateFromCSRRequest() error = %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func TestCertificateRequestStructures(t *testing.T) { + // Test CreateCertificateRequest structure + req := &CreateCertificateRequest{ + Name: "Test Certificate", + Description: "Test Description", + CommonName: "test.example.com", + Organization: "Test Organization", + OrganizationalUnit: "Test Unit", + Country: "DE", + State: "Bavaria", + Locality: "Munich", + Street: "Test Street", + Address: "Test Address", + PostalCode: "80331", + Email: "test@example.com", + Type: models.CertificateTypeWeb, + CertificateAuthorityID: uuid.New(), + NotAfter: func() *time.Time { t := time.Now().AddDate(1, 0, 0); return &t }(), + DNSNames: []string{"test.example.com", "www.test.example.com"}, + IPAddresses: []net.IP{net.ParseIP("192.168.1.1")}, + EmailAddresses: []string{"test@example.com"}, + URIs: []string{"https://test.example.com"}, + KeyType: "rsa", + KeySize: 2048, + Curve: "P-256", + KeyUsage: &KeyUsageConfig{ + DigitalSignature: true, + KeyEncipherment: true, + }, + ExtendedKeyUsage: &ExtendedKeyUsageConfig{ + ServerAuth: true, + }, + } + + if req.Name == "" { + t.Error("Certificate request name should not be empty") + } + if req.CommonName == "" { + t.Error("Certificate request common name should not be empty") + } + if req.CertificateAuthorityID == uuid.Nil { + t.Error("Certificate request CA ID should not be nil") + } + if req.Type == "" { + t.Error("Certificate request type should not be empty") + } + if len(req.DNSNames) == 0 { + t.Error("Certificate request should have DNS names") + } + if len(req.IPAddresses) == 0 { + t.Error("Certificate request should have IP addresses") + } + if req.KeyUsage == nil { + t.Error("Certificate request should have key usage") + } + if req.ExtendedKeyUsage == nil { + t.Error("Certificate request should have extended key usage") + } +} + +func TestKeyUsageConfig(t *testing.T) { + keyUsage := &KeyUsageConfig{ + DigitalSignature: true, + KeyEncipherment: true, + } + + if !keyUsage.DigitalSignature { + t.Error("Digital signature should be enabled") + } + if !keyUsage.KeyEncipherment { + t.Error("Key encipherment should be enabled") + } +} + +func TestExtendedKeyUsageConfig(t *testing.T) { + extKeyUsage := &ExtendedKeyUsageConfig{ + ServerAuth: true, + ClientAuth: false, + } + + if !extKeyUsage.ServerAuth { + t.Error("Server auth should be enabled") + } + if extKeyUsage.ClientAuth { + t.Error("Client auth should be disabled") + } +} + +func TestCertificateServiceCreation(t *testing.T) { + // Test that we can create a service instance + service := NewCertificateService(nil, "/tmp/certs", "/tmp/private", nil) + + if service == nil { + t.Error("CertificateService should not be nil") + } + + if service.certDir != "/tmp/certs" { + t.Errorf("Expected certDir to be /tmp/certs, got %s", service.certDir) + } + + if service.privateDir != "/tmp/private" { + t.Errorf("Expected privateDir to be /tmp/private, got %s", service.privateDir) + } +} + +func TestCertificateTypes(t *testing.T) { + // Test that all certificate types are defined + certTypes := []models.CertificateType{ + models.CertificateTypeWeb, + models.CertificateTypeServer, + models.CertificateTypeClient, + models.CertificateTypeUser, + models.CertificateTypeEmail, + models.CertificateTypeCode, + models.CertificateTypeCA, + models.CertificateTypeIoT, + models.CertificateTypeDevice, + models.CertificateTypeSensor, + models.CertificateTypeVPN, + models.CertificateTypeOpenVPN, + models.CertificateTypeWireGuard, + models.CertificateTypeDatabase, + models.CertificateTypeMySQL, + models.CertificateTypePostgreSQL, + models.CertificateTypeMongoDB, + models.CertificateTypeAPI, + models.CertificateTypeService, + models.CertificateTypeMicroservice, + models.CertificateTypeDocker, + models.CertificateTypeKubernetes, + models.CertificateTypeContainer, + models.CertificateTypeCloud, + models.CertificateTypeAWS, + models.CertificateTypeAzure, + models.CertificateTypeGCP, + models.CertificateTypeNetwork, + models.CertificateTypeFirewall, + models.CertificateTypeProxy, + models.CertificateTypeLoadBalancer, + models.CertificateTypeMobile, + models.CertificateTypeAndroid, + models.CertificateTypeiOS, + models.CertificateTypeApp, + models.CertificateTypeDocument, + models.CertificateTypePDF, + models.CertificateTypeOffice, + models.CertificateTypeTimestamp, + models.CertificateTypeOCSP, + models.CertificateTypeCustom, + models.CertificateTypeSpecial, + } + + for _, certType := range certTypes { + if certType == "" { + t.Errorf("Certificate type should not be empty") + } + } +} + +func TestCertificateTypeCategories(t *testing.T) { + // Test web certificates + webTypes := []models.CertificateType{ + models.CertificateTypeWeb, + models.CertificateTypeServer, + } + + for _, certType := range webTypes { + if certType == "" { + t.Errorf("Web certificate type should not be empty: %s", certType) + } + } + + // Test client certificates + clientTypes := []models.CertificateType{ + models.CertificateTypeClient, + models.CertificateTypeUser, + } + + for _, certType := range clientTypes { + if certType == "" { + t.Errorf("Client certificate type should not be empty: %s", certType) + } + } + + // Test IoT certificates + iotTypes := []models.CertificateType{ + models.CertificateTypeIoT, + models.CertificateTypeDevice, + models.CertificateTypeSensor, + } + + for _, certType := range iotTypes { + if certType == "" { + t.Errorf("IoT certificate type should not be empty: %s", certType) + } + } + + // Test VPN certificates + vpnTypes := []models.CertificateType{ + models.CertificateTypeVPN, + models.CertificateTypeOpenVPN, + models.CertificateTypeWireGuard, + } + + for _, certType := range vpnTypes { + if certType == "" { + t.Errorf("VPN certificate type should not be empty: %s", certType) + } + } + + // Test database certificates + dbTypes := []models.CertificateType{ + models.CertificateTypeDatabase, + models.CertificateTypeMySQL, + models.CertificateTypePostgreSQL, + models.CertificateTypeMongoDB, + } + + for _, certType := range dbTypes { + if certType == "" { + t.Errorf("Database certificate type should not be empty: %s", certType) + } + } +} diff --git a/certificate/utils/certificate.go b/certificate/utils/certificate.go new file mode 100644 index 0000000..6c1e313 --- /dev/null +++ b/certificate/utils/certificate.go @@ -0,0 +1,783 @@ +package utils + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/url" + "strings" + "time" + + "git.secnex.io/secnex/certman/models" +) + +// KeyType represents the type of cryptographic key +type KeyType string + +const ( + KeyTypeRSA KeyType = "rsa" + KeyTypeECDSA KeyType = "ecdsa" +) + +// KeySize represents the size of RSA keys +type KeySize int + +const ( + KeySize2048 KeySize = 2048 + KeySize3072 KeySize = 3072 + KeySize4096 KeySize = 4096 +) + +// Curve represents ECDSA curves +type Curve string + +const ( + CurveP256 Curve = "P-256" + CurveP384 Curve = "P-384" + CurveP521 Curve = "P-521" +) + +// CertificateConfig holds configuration for certificate generation +type CertificateConfig struct { + // Basic Information + CommonName string + Organization string + OrganizationalUnit string + Country string + State string + Locality string + Street string + PostalCode string + Email string + + // Certificate Details + SerialNumber *big.Int + NotBefore time.Time + NotAfter time.Time + CertificateType models.CertificateType + + // Subject Alternative Names + DNSNames []string + IPAddresses []net.IP + EmailAddresses []string + URIs []string + + // Key Configuration + KeyType KeyType + KeySize KeySize // Only used for RSA + Curve Curve // Only used for ECDSA + + // CA Configuration + IsCA bool + MaxPathLen int + MaxPathLenZero bool + BasicConstraintsValid bool + + // Key Usage + DigitalSignature bool + ContentCommitment bool + KeyEncipherment bool + DataEncipherment bool + KeyAgreement bool + KeyCertSign bool + CRLSign bool + EncipherOnly bool + DecipherOnly bool + + // Extended Key Usage + ServerAuth bool + ClientAuth bool + CodeSigning bool + EmailProtection bool + TimeStamping bool + OCSPSigning bool + + // Additional Extensions + SubjectKeyID bool + AuthorityKeyID bool + CRLDistributionPoints []string + OCSPServers []string + + // iOS-specific extensions + AuthorityInfoAccess []string // AIA URLs for iOS compatibility + IssuingCertificateURL string // URL to issuing certificate + + // Validity period configuration + ValidityYears int // Custom validity period in years (overrides default) + + // Platform compatibility + UseSHA256ForKeyIDs bool // Use SHA-256 for Key IDs (iOS compatible, MacBook compatible) + +} + +// DefaultCertificateConfig returns a default configuration for the given certificate type +func DefaultCertificateConfig(certType models.CertificateType) *CertificateConfig { + config := &CertificateConfig{ + CommonName: "", + Organization: "", + OrganizationalUnit: "", + Country: "DE", + State: "", + Locality: "", + Street: "", + PostalCode: "", + Email: "", + + SerialNumber: generateSerialNumber(), + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 3, 0), // 3 months default + CertificateType: certType, + + DNSNames: []string{}, + IPAddresses: []net.IP{}, + EmailAddresses: []string{}, + URIs: []string{}, + + KeyType: KeyTypeRSA, + KeySize: KeySize2048, + Curve: CurveP256, + + IsCA: false, + MaxPathLen: -1, + MaxPathLenZero: false, + BasicConstraintsValid: true, + + SubjectKeyID: true, + AuthorityKeyID: true, + CRLDistributionPoints: []string{}, + OCSPServers: []string{}, + + // iOS-specific defaults + AuthorityInfoAccess: []string{}, + IssuingCertificateURL: "", + + // Validity period defaults + ValidityYears: 0, // 0 means use default + + // Platform compatibility defaults + UseSHA256ForKeyIDs: true, // Use SHA-256 for modern compatibility (iOS + MacBook) + + } + + // Set defaults based on certificate type + switch certType { + case models.CertificateTypeCA: + config.IsCA = true + config.KeyCertSign = true + config.CRLSign = true + config.NotAfter = calculateValidityPeriod(config.NotBefore, config.ValidityYears, 10) // Default 10 years for CA + config.MaxPathLen = 0 + + case models.CertificateTypeWeb, models.CertificateTypeServer: + config.DigitalSignature = true + config.KeyEncipherment = true + config.ServerAuth = true + config.NotAfter = calculateValidityPeriodMonths(config.NotBefore, config.ValidityYears, 3) // Default 3 months for web/server + + // Enterprise defaults: Use larger key size if not specified + if config.KeySize == KeySize2048 && config.ValidityYears > 1 { + config.KeySize = KeySize4096 // Use 4096 bit for enterprise certificates + } + + // iOS-compatibility: Always ensure Common Name is in SAN + if config.CommonName != "" { + found := false + for _, dns := range config.DNSNames { + if dns == config.CommonName { + found = true + break + } + } + if !found { + config.DNSNames = append(config.DNSNames, config.CommonName) + } + } + + // iOS-compatibility: Add Authority Information Access + config.AuthorityInfoAccess = []string{"http://ca.secnex.internal/ca.crt"} + config.IssuingCertificateURL = "http://ca.secnex.internal/ca.crt" + + case models.CertificateTypeClient: + config.DigitalSignature = true + config.KeyEncipherment = true + config.ClientAuth = true + config.NotAfter = calculateValidityPeriod(config.NotBefore, config.ValidityYears, 1) // Default 1 year for client + + case models.CertificateTypeEmail: + config.DigitalSignature = true + config.KeyEncipherment = true + config.EmailProtection = true + config.NotAfter = calculateValidityPeriod(config.NotBefore, config.ValidityYears, 1) // Default 1 year for email + + case models.CertificateTypeCode: + config.DigitalSignature = true + config.CodeSigning = true + config.NotAfter = calculateValidityPeriod(config.NotBefore, config.ValidityYears, 2) // Default 2 years for code signing + } + + return config +} + +// CACertificateConfig returns a configuration specifically for CA certificates +func CACertificateConfig(isRoot bool) *CertificateConfig { + config := DefaultCertificateConfig(models.CertificateTypeCA) + config.IsCA = true + config.KeyCertSign = true + config.CRLSign = true + + if isRoot { + config.NotAfter = calculateValidityPeriod(config.NotBefore, config.ValidityYears, 20) // Default 20 years for root CA + config.MaxPathLen = -1 // No limit for root CA + } else { + config.NotAfter = calculateValidityPeriod(config.NotBefore, config.ValidityYears, 10) // Default 10 years for intermediate CA + config.MaxPathLen = 0 // Intermediate CA cannot issue other CAs by default + } + + return config +} + +// CACertificateConfigWithValidity returns a CA configuration with custom validity period +func CACertificateConfigWithValidity(isRoot bool, validityYears int) *CertificateConfig { + config := CACertificateConfig(isRoot) + config.ValidityYears = validityYears + + if isRoot { + config.NotAfter = calculateValidityPeriod(config.NotBefore, validityYears, 20) + } else { + config.NotAfter = calculateValidityPeriod(config.NotBefore, validityYears, 10) + } + + return config +} + +// WebServerCertificateConfig returns a configuration for web server certificates +func WebServerCertificateConfig(domains []string) *CertificateConfig { + config := DefaultCertificateConfig(models.CertificateTypeWeb) + config.DNSNames = domains + config.DigitalSignature = true + config.KeyEncipherment = true + config.ServerAuth = true + config.NotAfter = config.NotBefore.AddDate(0, 3, 0) // 3 months for web certificates + return config +} + +// ClientCertificateConfig returns a configuration for client certificates +func ClientCertificateConfig() *CertificateConfig { + config := DefaultCertificateConfig(models.CertificateTypeClient) + config.DigitalSignature = true + config.KeyEncipherment = true + config.ClientAuth = true + return config +} + +// EmailCertificateConfig returns a configuration for email certificates +func EmailCertificateConfig(email string) *CertificateConfig { + config := DefaultCertificateConfig(models.CertificateTypeEmail) + config.EmailAddresses = []string{email} + config.DigitalSignature = true + config.KeyEncipherment = true + config.EmailProtection = true + return config +} + +// CodeSigningCertificateConfig returns a configuration for code signing certificates +func CodeSigningCertificateConfig() *CertificateConfig { + config := DefaultCertificateConfig(models.CertificateTypeCode) + config.DigitalSignature = true + config.CodeSigning = true + return config +} + +// CertificateGenerator handles certificate generation +type CertificateGenerator struct { + config *CertificateConfig +} + +// NewCertificateGenerator creates a new certificate generator +func NewCertificateGenerator(config *CertificateConfig) *CertificateGenerator { + return &CertificateGenerator{ + config: config, + } +} + +// GenerateKeyPair generates a key pair based on the configuration +func (cg *CertificateGenerator) GenerateKeyPair() (interface{}, error) { + switch cg.config.KeyType { + case KeyTypeRSA: + return rsa.GenerateKey(rand.Reader, int(cg.config.KeySize)) + case KeyTypeECDSA: + var curve elliptic.Curve + switch cg.config.Curve { + case CurveP256: + curve = elliptic.P256() + case CurveP384: + curve = elliptic.P384() + case CurveP521: + curve = elliptic.P521() + default: + return nil, fmt.Errorf("unsupported ECDSA curve: %s", cg.config.Curve) + } + return ecdsa.GenerateKey(curve, rand.Reader) + default: + return nil, fmt.Errorf("unsupported key type: %s", cg.config.KeyType) + } +} + +// GenerateSelfSignedCertificate generates a self-signed certificate +func (cg *CertificateGenerator) GenerateSelfSignedCertificate() (*x509.Certificate, interface{}, error) { + // Generate key pair + privateKey, err := cg.GenerateKeyPair() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate key pair: %w", err) + } + + // Create certificate template + template := cg.createCertificateTemplate() + + // Generate certificate + certDER, err := x509.CreateCertificate(rand.Reader, template, template, cg.getPublicKey(privateKey), privateKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Parse certificate + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return cert, privateKey, nil +} + +// GenerateCertificate generates a certificate signed by a CA +func (cg *CertificateGenerator) GenerateCertificate(caCert *x509.Certificate, caPrivateKey interface{}) (*x509.Certificate, interface{}, error) { + // Generate key pair + privateKey, err := cg.GenerateKeyPair() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate key pair: %w", err) + } + + // Create certificate template + template := cg.createCertificateTemplate() + + // Generate certificate + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, cg.getPublicKey(privateKey), caPrivateKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Parse certificate + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return cert, privateKey, nil +} + +// createCertificateTemplate creates an x509.Certificate template from the configuration +func (cg *CertificateGenerator) createCertificateTemplate() *x509.Certificate { + template := &x509.Certificate{ + SerialNumber: cg.config.SerialNumber, + Subject: pkix.Name{ + CommonName: cg.config.CommonName, + Organization: []string{cg.config.Organization}, + OrganizationalUnit: []string{cg.config.OrganizationalUnit}, + Country: []string{cg.config.Country}, + Province: []string{cg.config.State}, + Locality: []string{cg.config.Locality}, + StreetAddress: []string{cg.config.Street}, + PostalCode: []string{cg.config.PostalCode}, + }, + NotBefore: cg.config.NotBefore, + NotAfter: cg.config.NotAfter, + DNSNames: cg.config.DNSNames, + IPAddresses: cg.config.IPAddresses, + EmailAddresses: cg.config.EmailAddresses, + URIs: parseURIs(cg.config.URIs), + } + + // Set key usage + var keyUsage x509.KeyUsage + if cg.config.DigitalSignature { + keyUsage |= x509.KeyUsageDigitalSignature + } + if cg.config.ContentCommitment { + keyUsage |= x509.KeyUsageContentCommitment + } + if cg.config.KeyEncipherment { + keyUsage |= x509.KeyUsageKeyEncipherment + } + if cg.config.DataEncipherment { + keyUsage |= x509.KeyUsageDataEncipherment + } + if cg.config.KeyAgreement { + keyUsage |= x509.KeyUsageKeyAgreement + } + if cg.config.KeyCertSign { + keyUsage |= x509.KeyUsageCertSign + } + if cg.config.CRLSign { + keyUsage |= x509.KeyUsageCRLSign + } + if cg.config.EncipherOnly { + keyUsage |= x509.KeyUsageEncipherOnly + } + if cg.config.DecipherOnly { + keyUsage |= x509.KeyUsageDecipherOnly + } + template.KeyUsage = keyUsage + + // Set extended key usage + var extKeyUsage []x509.ExtKeyUsage + if cg.config.ServerAuth { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageServerAuth) + } + if cg.config.ClientAuth { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageClientAuth) + } + if cg.config.CodeSigning { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageCodeSigning) + } + if cg.config.EmailProtection { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageEmailProtection) + } + if cg.config.TimeStamping { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageTimeStamping) + } + if cg.config.OCSPSigning { + extKeyUsage = append(extKeyUsage, x509.ExtKeyUsageOCSPSigning) + } + template.ExtKeyUsage = extKeyUsage + + // Set basic constraints + if cg.config.BasicConstraintsValid { + template.BasicConstraintsValid = true + template.IsCA = cg.config.IsCA + if cg.config.IsCA { + template.MaxPathLen = cg.config.MaxPathLen + template.MaxPathLenZero = cg.config.MaxPathLenZero + } + } + + // Set subject key identifier + if cg.config.SubjectKeyID { + template.SubjectKeyId = generateSubjectKeyID(cg.getPublicKey(nil)) + } + + // Set authority key identifier + if cg.config.AuthorityKeyID { + template.AuthorityKeyId = generateAuthorityKeyID(cg.getPublicKey(nil)) + } + + // Add CRL distribution points + if len(cg.config.CRLDistributionPoints) > 0 { + template.CRLDistributionPoints = cg.config.CRLDistributionPoints + } + + // Add OCSP servers + if len(cg.config.OCSPServers) > 0 { + template.OCSPServer = cg.config.OCSPServers + } + + // Add Authority Information Access (AIA) for iOS compatibility + if len(cg.config.AuthorityInfoAccess) > 0 { + template.IssuingCertificateURL = cg.config.AuthorityInfoAccess + } + + return template +} + +// getPublicKey extracts the public key from a private key +func (cg *CertificateGenerator) getPublicKey(privateKey interface{}) interface{} { + if privateKey == nil { + return nil + } + + switch k := privateKey.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +} + +// CertificateExporter handles certificate export functionality +type CertificateExporter struct{} + +// NewCertificateExporter creates a new certificate exporter +func NewCertificateExporter() *CertificateExporter { + return &CertificateExporter{} +} + +// ExportCertificateToPEM exports a certificate to PEM format +func (ce *CertificateExporter) ExportCertificateToPEM(cert *x509.Certificate) ([]byte, error) { + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }), nil +} + +// ExportPrivateKeyToPEM exports a private key to PEM format +func (ce *CertificateExporter) ExportPrivateKeyToPEM(privateKey interface{}) ([]byte, error) { + var pemType string + var keyBytes []byte + var err error + + switch k := privateKey.(type) { + case *rsa.PrivateKey: + pemType = "RSA PRIVATE KEY" + keyBytes = x509.MarshalPKCS1PrivateKey(k) + case *ecdsa.PrivateKey: + pemType = "EC PRIVATE KEY" + keyBytes, err = x509.MarshalECPrivateKey(k) + if err != nil { + return nil, fmt.Errorf("failed to marshal ECDSA private key: %w", err) + } + default: + return nil, fmt.Errorf("unsupported private key type") + } + + return pem.EncodeToMemory(&pem.Block{ + Type: pemType, + Bytes: keyBytes, + }), nil +} + +// ExportCertificateToDER exports a certificate to DER format +func (ce *CertificateExporter) ExportCertificateToDER(cert *x509.Certificate) []byte { + return cert.Raw +} + +// ExportPrivateKeyToDER exports a private key to DER format +func (ce *CertificateExporter) ExportPrivateKeyToDER(privateKey interface{}) ([]byte, error) { + switch k := privateKey.(type) { + case *rsa.PrivateKey: + return x509.MarshalPKCS1PrivateKey(k), nil + case *ecdsa.PrivateKey: + return x509.MarshalECPrivateKey(k) + default: + return nil, fmt.Errorf("unsupported private key type") + } +} + +// Utility functions + +// calculateValidityPeriod calculates the validity period based on custom years or default +func calculateValidityPeriod(notBefore time.Time, customYears int, defaultYears int) time.Time { + if customYears > 0 { + return notBefore.AddDate(customYears, 0, 0) + } + return notBefore.AddDate(defaultYears, 0, 0) +} + +// calculateValidityPeriodMonths calculates the validity period based on custom years or default months +func calculateValidityPeriodMonths(notBefore time.Time, customYears int, defaultMonths int) time.Time { + if customYears > 0 { + return notBefore.AddDate(customYears, 0, 0) + } + return notBefore.AddDate(0, defaultMonths, 0) +} + +// generateSerialNumber generates a random serial number +func generateSerialNumber() *big.Int { + serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + return serial +} + +// generateSubjectKeyID generates a subject key identifier using SHA-256 (modern compatibility) +func generateSubjectKeyID(publicKey interface{}) []byte { + if publicKey == nil { + return nil + } + + var publicKeyBytes []byte + var err error + + switch k := publicKey.(type) { + case *rsa.PublicKey: + publicKeyBytes, err = asn1.Marshal(k) + case *ecdsa.PublicKey: + publicKeyBytes, err = asn1.Marshal(k) + default: + return nil + } + + if err != nil { + return nil + } + + // Use SHA-256 for modern compatibility (iOS + MacBook + modern browsers) + hash := sha256.Sum256(publicKeyBytes) + return hash[:] +} + +// generateAuthorityKeyID generates an authority key identifier +func generateAuthorityKeyID(publicKey interface{}) []byte { + return generateSubjectKeyID(publicKey) +} + +// parseURIs converts string URIs to *url.URL +func parseURIs(uriStrings []string) []*url.URL { + var uris []*url.URL + for _, uriStr := range uriStrings { + if uri, err := url.Parse(uriStr); err == nil { + uris = append(uris, uri) + } + } + return uris +} + +// ParseSANs parses Subject Alternative Names from a string +func ParseSANs(sans string) ([]string, []net.IP, []string, []string) { + var dnsNames []string + var ipAddresses []net.IP + var emailAddresses []string + var uris []string + + if sans == "" { + return dnsNames, ipAddresses, emailAddresses, uris + } + + parts := strings.Split(sans, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + // Check if it's an IP address + if ip := net.ParseIP(part); ip != nil { + ipAddresses = append(ipAddresses, ip) + continue + } + + // Check if it's an email address + if strings.Contains(part, "@") { + emailAddresses = append(emailAddresses, part) + continue + } + + // Check if it's a URI + if strings.HasPrefix(part, "http://") || strings.HasPrefix(part, "https://") { + uris = append(uris, part) + continue + } + + // Default to DNS name + dnsNames = append(dnsNames, part) + } + + return dnsNames, ipAddresses, emailAddresses, uris +} + +// ValidateCertificateConfig validates a certificate configuration +func ValidateCertificateConfig(config *CertificateConfig) error { + if config.CommonName == "" { + return fmt.Errorf("common name is required") + } + + if config.Organization == "" { + return fmt.Errorf("organization is required") + } + + if config.Country == "" { + return fmt.Errorf("country is required") + } + + if config.NotBefore.After(config.NotAfter) { + return fmt.Errorf("not before date must be before not after date") + } + + if config.IsCA && !config.KeyCertSign { + return fmt.Errorf("CA certificates must have key cert sign usage") + } + + // Modern browser and iOS validations + if config.CertificateType == models.CertificateTypeWeb { + // Ensure Common Name is in SAN (required by modern browsers and iOS) + if config.CommonName != "" { + found := false + for _, dns := range config.DNSNames { + if dns == config.CommonName { + found = true + break + } + } + if !found { + return fmt.Errorf("Common Name must be included in Subject Alternative Names for modern browser compatibility") + } + } + + // Validate key size (minimum 2048 bits required) + if config.KeyType == KeyTypeRSA && config.KeySize < KeySize2048 { + return fmt.Errorf("RSA keys must be at least 2048 bits") + } + + // Validate validity period (maximum 825 days, recommended 398 days) + validityDays := int(config.NotAfter.Sub(config.NotBefore).Hours() / 24) + if validityDays > 825 { + return fmt.Errorf("certificate validity period exceeds 825 days (current: %d days)", validityDays) + } + } + + return nil +} + +// CreateCertificateFromModel creates a certificate configuration from a model +func CreateCertificateFromModel(cert *models.Certificate) *CertificateConfig { + config := DefaultCertificateConfig(cert.Type) + + config.CommonName = cert.AttributeCommonName + config.Organization = cert.AttributeOrganization + config.OrganizationalUnit = cert.AttributeOrganizationUnit + config.Country = cert.AttributeCountry + config.State = cert.AttributeState + config.Locality = cert.AttributeLocality + config.Street = cert.AttributeStreet + config.PostalCode = cert.AttributePostalCode + config.Email = cert.AttributeEmail + config.NotBefore = cert.AttributeNotBefore + config.NotAfter = cert.AttributeNotAfter + + // Parse SANs + if cert.AttributeSubjectAlternativeName != "" { + dnsNames, ipAddresses, emailAddresses, uris := ParseSANs(cert.AttributeSubjectAlternativeName) + config.DNSNames = dnsNames + config.IPAddresses = ipAddresses + config.EmailAddresses = emailAddresses + config.URIs = uris + } + + return config +} + +// CreateCAFromModel creates a CA configuration from a model +func CreateCAFromModel(ca *models.CertificateAuthority) *CertificateConfig { + config := CACertificateConfig(ca.Root) + + config.CommonName = ca.AttributeCommonName + config.Organization = ca.AttributeOrganization + config.OrganizationalUnit = ca.AttributeOrganizationUnit + config.Country = ca.AttributeCountry + config.State = ca.AttributeState + config.Locality = ca.AttributeLocality + config.Street = ca.AttributeStreet + config.PostalCode = ca.AttributePostalCode + config.Email = ca.AttributeEmail + config.NotBefore = ca.AttributeNotBefore + config.NotAfter = ca.AttributeNotAfter + + return config +} diff --git a/certificate/utils/certificate_test.go b/certificate/utils/certificate_test.go new file mode 100644 index 0000000..0c97f19 --- /dev/null +++ b/certificate/utils/certificate_test.go @@ -0,0 +1,537 @@ +package utils + +import ( + "crypto/x509" + "testing" + "time" + + "git.secnex.io/secnex/certman/models" +) + +func TestDefaultCertificateConfig(t *testing.T) { + tests := []struct { + certType models.CertificateType + expected bool + }{ + {models.CertificateTypeCA, true}, + {models.CertificateTypeWeb, true}, + {models.CertificateTypeClient, true}, + {models.CertificateTypeEmail, true}, + {models.CertificateTypeCode, true}, + {models.CertificateTypeServer, true}, + } + + for _, test := range tests { + config := DefaultCertificateConfig(test.certType) + if config.CertificateType != test.certType { + t.Errorf("Expected certificate type %s, got %s", test.certType, config.CertificateType) + } + if config.SerialNumber == nil { + t.Error("Serial number should not be nil") + } + if config.NotBefore.After(config.NotAfter) { + t.Error("NotBefore should be before NotAfter") + } + } +} + +func TestCACertificateConfig(t *testing.T) { + // Test root CA + rootConfig := CACertificateConfig(true) + if !rootConfig.IsCA { + t.Error("Root CA should have IsCA = true") + } + if !rootConfig.KeyCertSign { + t.Error("Root CA should have KeyCertSign = true") + } + if !rootConfig.CRLSign { + t.Error("Root CA should have CRLSign = true") + } + if rootConfig.MaxPathLen != -1 { + t.Error("Root CA should have MaxPathLen = -1") + } + + // Test intermediate CA + intermediateConfig := CACertificateConfig(false) + if !intermediateConfig.IsCA { + t.Error("Intermediate CA should have IsCA = true") + } + if !intermediateConfig.KeyCertSign { + t.Error("Intermediate CA should have KeyCertSign = true") + } + if !intermediateConfig.CRLSign { + t.Error("Intermediate CA should have CRLSign = true") + } + if intermediateConfig.MaxPathLen != 0 { + t.Error("Intermediate CA should have MaxPathLen = 0") + } +} + +func TestWebServerCertificateConfig(t *testing.T) { + domains := []string{"example.com", "www.example.com"} + config := WebServerCertificateConfig(domains) + + if config.CertificateType != models.CertificateTypeWeb { + t.Error("Should be web certificate type") + } + if !config.DigitalSignature { + t.Error("Should have digital signature") + } + if !config.KeyEncipherment { + t.Error("Should have key encipherment") + } + if !config.ServerAuth { + t.Error("Should have server auth") + } + if len(config.DNSNames) != len(domains) { + t.Error("DNS names should match input domains") + } +} + +func TestClientCertificateConfig(t *testing.T) { + config := ClientCertificateConfig() + + if config.CertificateType != models.CertificateTypeClient { + t.Error("Should be client certificate type") + } + if !config.DigitalSignature { + t.Error("Should have digital signature") + } + if !config.KeyEncipherment { + t.Error("Should have key encipherment") + } + if !config.ClientAuth { + t.Error("Should have client auth") + } +} + +func TestEmailCertificateConfig(t *testing.T) { + email := "test@example.com" + config := EmailCertificateConfig(email) + + if config.CertificateType != models.CertificateTypeEmail { + t.Error("Should be email certificate type") + } + if !config.DigitalSignature { + t.Error("Should have digital signature") + } + if !config.KeyEncipherment { + t.Error("Should have key encipherment") + } + if !config.EmailProtection { + t.Error("Should have email protection") + } + if len(config.EmailAddresses) != 1 || config.EmailAddresses[0] != email { + t.Error("Email addresses should match input") + } +} + +func TestCodeSigningCertificateConfig(t *testing.T) { + config := CodeSigningCertificateConfig() + + if config.CertificateType != models.CertificateTypeCode { + t.Error("Should be code signing certificate type") + } + if !config.DigitalSignature { + t.Error("Should have digital signature") + } + if !config.CodeSigning { + t.Error("Should have code signing") + } +} + +func TestGenerateKeyPair(t *testing.T) { + tests := []struct { + keyType KeyType + keySize KeySize + curve Curve + }{ + {KeyTypeRSA, KeySize2048, CurveP256}, + {KeyTypeRSA, KeySize3072, CurveP256}, + {KeyTypeRSA, KeySize4096, CurveP256}, + {KeyTypeECDSA, KeySize2048, CurveP256}, + {KeyTypeECDSA, KeySize2048, CurveP384}, + {KeyTypeECDSA, KeySize2048, CurveP521}, + } + + for _, test := range tests { + config := DefaultCertificateConfig(models.CertificateTypeWeb) + config.KeyType = test.keyType + config.KeySize = test.keySize + config.Curve = test.curve + + generator := NewCertificateGenerator(config) + key, err := generator.GenerateKeyPair() + if err != nil { + t.Errorf("Failed to generate key pair for %s: %v", test.keyType, err) + } + if key == nil { + t.Errorf("Generated key should not be nil for %s", test.keyType) + } + } +} + +func TestGenerateSelfSignedCertificate(t *testing.T) { + config := CACertificateConfig(true) + config.CommonName = "Test Root CA" + config.Organization = "Test Org" + config.Country = "DE" + + generator := NewCertificateGenerator(config) + cert, privateKey, err := generator.GenerateSelfSignedCertificate() + if err != nil { + t.Fatalf("Failed to generate self-signed certificate: %v", err) + } + + if cert == nil { + t.Error("Certificate should not be nil") + } + if privateKey == nil { + t.Error("Private key should not be nil") + } + if !cert.IsCA { + t.Error("CA certificate should have IsCA = true") + } + if cert.Subject.CommonName != config.CommonName { + t.Error("Certificate subject should match configuration") + } +} + +func TestGenerateCertificate(t *testing.T) { + // Generate root CA first + rootConfig := CACertificateConfig(true) + rootConfig.CommonName = "Test Root CA" + rootConfig.Organization = "Test Org" + rootConfig.Country = "DE" + + rootGenerator := NewCertificateGenerator(rootConfig) + rootCert, rootPrivateKey, err := rootGenerator.GenerateSelfSignedCertificate() + if err != nil { + t.Fatalf("Failed to generate root CA: %v", err) + } + + // Generate intermediate CA + intermediateConfig := CACertificateConfig(false) + intermediateConfig.CommonName = "Test Intermediate CA" + intermediateConfig.Organization = "Test Org" + intermediateConfig.Country = "DE" + + intermediateGenerator := NewCertificateGenerator(intermediateConfig) + intermediateCert, intermediatePrivateKey, err := intermediateGenerator.GenerateCertificate(rootCert, rootPrivateKey) + if err != nil { + t.Fatalf("Failed to generate intermediate CA: %v", err) + } + + if intermediateCert == nil { + t.Error("Intermediate certificate should not be nil") + } + if intermediatePrivateKey == nil { + t.Error("Intermediate private key should not be nil") + } + if !intermediateCert.IsCA { + t.Error("Intermediate CA certificate should have IsCA = true") + } + if intermediateCert.Issuer.CommonName != rootCert.Subject.CommonName { + t.Error("Intermediate certificate issuer should match root certificate subject") + } + + // Generate end entity certificate + endConfig := WebServerCertificateConfig([]string{"example.com"}) + endConfig.CommonName = "example.com" + endConfig.Organization = "Test Org" + endConfig.Country = "DE" + + endGenerator := NewCertificateGenerator(endConfig) + endCert, endPrivateKey, err := endGenerator.GenerateCertificate(intermediateCert, intermediatePrivateKey) + if err != nil { + t.Fatalf("Failed to generate end entity certificate: %v", err) + } + + if endCert == nil { + t.Error("End entity certificate should not be nil") + } + if endPrivateKey == nil { + t.Error("End entity private key should not be nil") + } + if endCert.IsCA { + t.Error("End entity certificate should not have IsCA = true") + } + if endCert.Issuer.CommonName != intermediateCert.Subject.CommonName { + t.Error("End entity certificate issuer should match intermediate certificate subject") + } +} + +func TestValidateCertificateConfig(t *testing.T) { + tests := []struct { + name string + config *CertificateConfig + wantErr bool + }{ + { + name: "valid config", + config: &CertificateConfig{ + CommonName: "example.com", + Organization: "Test Org", + Country: "DE", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + }, + wantErr: false, + }, + { + name: "missing common name", + config: &CertificateConfig{ + Organization: "Test Org", + Country: "DE", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + }, + wantErr: true, + }, + { + name: "missing organization", + config: &CertificateConfig{ + CommonName: "example.com", + Country: "DE", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + }, + wantErr: true, + }, + { + name: "missing country", + config: &CertificateConfig{ + CommonName: "example.com", + Organization: "Test Org", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + }, + wantErr: true, + }, + { + name: "invalid date range", + config: &CertificateConfig{ + CommonName: "example.com", + Organization: "Test Org", + Country: "DE", + NotBefore: time.Now().AddDate(1, 0, 0), + NotAfter: time.Now(), + }, + wantErr: true, + }, + { + name: "CA without key cert sign", + config: &CertificateConfig{ + CommonName: "example.com", + Organization: "Test Org", + Country: "DE", + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + IsCA: true, + KeyCertSign: false, + }, + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := ValidateCertificateConfig(test.config) + if (err != nil) != test.wantErr { + t.Errorf("ValidateCertificateConfig() error = %v, wantErr %v", err, test.wantErr) + } + }) + } +} + +func TestParseSANs(t *testing.T) { + tests := []struct { + input string + expectedDNS []string + expectedIPs int + expectedEmails []string + expectedURIs []string + }{ + { + input: "example.com,www.example.com", + expectedDNS: []string{"example.com", "www.example.com"}, + expectedIPs: 0, + expectedEmails: []string{}, + expectedURIs: []string{}, + }, + { + input: "example.com,192.168.1.1,user@example.com,https://example.com", + expectedDNS: []string{"example.com"}, + expectedIPs: 1, + expectedEmails: []string{"user@example.com"}, + expectedURIs: []string{"https://example.com"}, + }, + { + input: "", + expectedDNS: []string{}, + expectedIPs: 0, + expectedEmails: []string{}, + expectedURIs: []string{}, + }, + } + + for _, test := range tests { + dnsNames, ipAddresses, emailAddresses, uris := ParseSANs(test.input) + + if len(dnsNames) != len(test.expectedDNS) { + t.Errorf("Expected %d DNS names, got %d", len(test.expectedDNS), len(dnsNames)) + } + + if len(ipAddresses) != test.expectedIPs { + t.Errorf("Expected %d IP addresses, got %d", test.expectedIPs, len(ipAddresses)) + } + + if len(emailAddresses) != len(test.expectedEmails) { + t.Errorf("Expected %d email addresses, got %d", len(test.expectedEmails), len(emailAddresses)) + } + + if len(uris) != len(test.expectedURIs) { + t.Errorf("Expected %d URIs, got %d", len(test.expectedURIs), len(uris)) + } + } +} + +func TestExportCertificateToPEM(t *testing.T) { + // Generate a test certificate + config := CACertificateConfig(true) + config.CommonName = "Test CA" + config.Organization = "Test Org" + config.Country = "DE" + + generator := NewCertificateGenerator(config) + cert, _, err := generator.GenerateSelfSignedCertificate() + if err != nil { + t.Fatalf("Failed to generate certificate: %v", err) + } + + exporter := NewCertificateExporter() + pemData, err := exporter.ExportCertificateToPEM(cert) + if err != nil { + t.Fatalf("Failed to export certificate to PEM: %v", err) + } + + if len(pemData) == 0 { + t.Error("PEM data should not be empty") + } + + // Verify PEM format + if !contains(pemData, []byte("-----BEGIN CERTIFICATE-----")) { + t.Error("PEM data should contain certificate header") + } + if !contains(pemData, []byte("-----END CERTIFICATE-----")) { + t.Error("PEM data should contain certificate footer") + } +} + +func TestExportPrivateKeyToPEM(t *testing.T) { + // Generate a test certificate with private key + config := CACertificateConfig(true) + config.CommonName = "Test CA" + config.Organization = "Test Org" + config.Country = "DE" + + generator := NewCertificateGenerator(config) + _, privateKey, err := generator.GenerateSelfSignedCertificate() + if err != nil { + t.Fatalf("Failed to generate certificate: %v", err) + } + + exporter := NewCertificateExporter() + pemData, err := exporter.ExportPrivateKeyToPEM(privateKey) + if err != nil { + t.Fatalf("Failed to export private key to PEM: %v", err) + } + + if len(pemData) == 0 { + t.Error("PEM data should not be empty") + } + + // Verify PEM format + if !contains(pemData, []byte("-----BEGIN")) { + t.Error("PEM data should contain private key header") + } + if !contains(pemData, []byte("-----END")) { + t.Error("PEM data should contain private key footer") + } +} + +func TestExportCertificateToDER(t *testing.T) { + // Generate a test certificate + config := CACertificateConfig(true) + config.CommonName = "Test CA" + config.Organization = "Test Org" + config.Country = "DE" + + generator := NewCertificateGenerator(config) + cert, _, err := generator.GenerateSelfSignedCertificate() + if err != nil { + t.Fatalf("Failed to generate certificate: %v", err) + } + + exporter := NewCertificateExporter() + derData := exporter.ExportCertificateToDER(cert) + + if len(derData) == 0 { + t.Error("DER data should not be empty") + } + + // Verify DER format by parsing it + _, err = x509.ParseCertificate(derData) + if err != nil { + t.Errorf("DER data should be valid certificate: %v", err) + } +} + +func TestExportPrivateKeyToDER(t *testing.T) { + // Generate a test certificate with private key + config := CACertificateConfig(true) + config.CommonName = "Test CA" + config.Organization = "Test Org" + config.Country = "DE" + + generator := NewCertificateGenerator(config) + _, privateKey, err := generator.GenerateSelfSignedCertificate() + if err != nil { + t.Fatalf("Failed to generate certificate: %v", err) + } + + exporter := NewCertificateExporter() + derData, err := exporter.ExportPrivateKeyToDER(privateKey) + if err != nil { + t.Fatalf("Failed to export private key to DER: %v", err) + } + + if len(derData) == 0 { + t.Error("DER data should not be empty") + } +} + +// Helper function to check if a slice contains a subslice +func contains(s, subslice []byte) bool { + if len(subslice) > len(s) { + return false + } + for i := 0; i <= len(s)-len(subslice); i++ { + if bytesEqual(s[i:i+len(subslice)], subslice) { + return true + } + } + return false +} + +func bytesEqual(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} \ No newline at end of file