package middlewares

import (
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"net/http"
	"strings"

	"git.secnex.io/secnex/api-gateway/config"
	"git.secnex.io/secnex/api-gateway/res"
	"git.secnex.io/secnex/masterlog"
	"golang.org/x/crypto/argon2"
)

// AuthMiddleware handles authentication based on header validation and path
// filtering. It wraps another http.Handler and only forwards a request when
// the path does not require authentication or the credential in the
// configured header has been verified.
type AuthMiddleware struct {
	header     string          // name of the request header carrying the credential
	authType   string          // authentication scheme; only "api_key" is implemented
	pathConfig config.AuthPath // include/exclude patterns selecting protected paths
	keys       []config.ApiKey // configured key IDs with their argon2id hashes
	handler    http.Handler    // next handler in the chain
}

// NewAuthMiddleware creates a new authentication middleware.
//
// header is the request header the credential is read from, authType selects
// the scheme ("api_key" is the only one that can succeed), pathConfig decides
// which paths are protected (see requiresAuth), keys are the accepted API
// keys, and handler is the handler invoked once a request is let through.
func NewAuthMiddleware(header string, authType string, pathConfig config.AuthPath, keys []config.ApiKey, handler http.Handler) http.Handler {
	masterlog.Debug("Creating AuthMiddleware", map[string]interface{}{
		"header":        header,
		"type":          authType,
		"include_paths": pathConfig.Include,
		"exclude_paths": pathConfig.Exclude,
		"keys_count":    len(keys),
	})

	return &AuthMiddleware{
		header:     header,
		authType:   authType,
		pathConfig: pathConfig,
		keys:       keys,
		handler:    handler,
	}
}

// ServeHTTP handles the authentication logic.
//
// Requests for paths that do not require authentication are forwarded
// untouched. All other requests must carry a non-empty credential header that
// validates for the configured auth type; otherwise res.Unauthorized is
// written and the wrapped handler is not called.
//
// NOTE(review): r.URL.Path is matched as received. If the upstream router
// normalises ".." or "//" segments, consider cleaning the path before
// matching so exclude patterns cannot be used to bypass authentication.
func (m *AuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	requestPath := r.URL.Path

	// Step 1: Determine if this path requires authentication.
	requiresAuth := m.requiresAuth(requestPath)
	masterlog.Debug("AuthMiddleware: Checking if path requires auth", map[string]interface{}{
		"path":          requestPath,
		"requires_auth": requiresAuth,
		"include":       m.pathConfig.Include,
		"exclude":       m.pathConfig.Exclude,
	})

	if !requiresAuth {
		// No auth required, skip to next handler.
		masterlog.Debug("AuthMiddleware: Skipping auth for path", map[string]interface{}{
			"path":   requestPath,
			"reason": "path_matches_exclude_or_not_include",
		})
		m.handler.ServeHTTP(w, r)
		return
	}

	// Step 2: Check if auth header is present.
	authHeader := r.Header.Get(m.header)
	if authHeader == "" {
		masterlog.Debug("AuthMiddleware: Missing auth header", map[string]interface{}{
			"path":   requestPath,
			"header": m.header,
		})
		res.Unauthorized(w)
		return
	}

	// Step 3: Validate the credential. Every branch that does not positively
	// verify the credential must reject the request (fail closed).
	switch m.authType {
	case "api_key":
		masterlog.Debug("AuthMiddleware: API key authentication", map[string]interface{}{
			"path":   requestPath,
			"header": m.header,
		})
		keyID, ok := m.authenticateAPIKey(authHeader, requestPath)
		if !ok {
			res.Unauthorized(w)
			return
		}
		masterlog.Debug("AuthMiddleware: API key authentication successful", map[string]interface{}{
			"path":   requestPath,
			"key_id": keyID,
		})

	case "bearer_token":
		masterlog.Debug("AuthMiddleware: Bearer token authentication", map[string]interface{}{
			"path":   requestPath,
			"header": m.header,
		})
		// No token validation exists yet. Forwarding here would accept any
		// non-empty header value, so the request is rejected instead.
		masterlog.Debug("AuthMiddleware: Bearer token authentication failed", map[string]interface{}{
			"path":   requestPath,
			"header": m.header,
			"error":  "bearer_token_not_implemented",
		})
		res.Unauthorized(w)
		return

	default:
		masterlog.Debug("AuthMiddleware: Unsupported auth type", map[string]interface{}{
			"path": requestPath,
			"type": m.authType,
		})
		res.Unauthorized(w)
		return
	}

	// Step 4: Forward to next handler.
	m.handler.ServeHTTP(w, r)
}

// authenticateAPIKey validates an API key credential.
//
// credential is the raw header value, expected to be standard base64 of
// "<key-id>:<secret>". Only the first ":" separates the two parts, so the
// secret itself may contain colons. The secret is checked against the
// argon2id hash of every configured key whose ID equals the key ID.
//
// It returns the matching key ID and true on success, or "" and false after
// logging the reason for the failure.
func (m *AuthMiddleware) authenticateAPIKey(credential, requestPath string) (string, bool) {
	reject := func(reason interface{}) (string, bool) {
		masterlog.Debug("AuthMiddleware: API key authentication failed", map[string]interface{}{
			"path":   requestPath,
			"header": m.header,
			"error":  reason,
		})
		return "", false
	}

	decoded, err := base64.StdEncoding.DecodeString(credential)
	if err != nil {
		return reject(err)
	}

	fields := strings.SplitN(string(decoded), ":", 2)
	if len(fields) != 2 {
		return reject("invalid_api_key_format")
	}
	keyID, secret := fields[0], fields[1]

	// Find matching key by ID and verify its argon2 hash.
	for _, key := range m.keys {
		if key.ID == keyID && m.verifyArgon2Hash(secret, key.Key) {
			return key.ID, true
		}
	}

	return reject("invalid_api_key")
}

// requiresAuth determines if a given path requires authentication.
//
// Logic:
//  1. If BOTH include and exclude are empty → auth required for ALL paths.
//  2. If include is set (non-empty) → auth required ONLY for paths matching
//     an include pattern; exclude is ignored in that case.
//  3. If ONLY exclude is set (non-empty) → auth required for ALL paths EXCEPT
//     those matching an exclude pattern.
//
// Wildcard patterns are supported (see matchPattern):
//   - "*" matches any path
//   - "/api/*" matches "/api/" and any subpath like "/api/users", "/api/users/123"
//   - "/api/v1/public/test/*" matches "/api/v1/public/test/" and any subpath
//     such as "/api/v1/public/test/123"
func (m *AuthMiddleware) requiresAuth(path string) bool {
	include := m.pathConfig.Include
	exclude := m.pathConfig.Exclude

	includeEmpty := len(include) == 0
	excludeEmpty := len(exclude) == 0

	masterlog.Debug("AuthMiddleware: Evaluating auth requirement", map[string]interface{}{
		"path":          path,
		"include_empty": includeEmpty,
		"exclude_empty": excludeEmpty,
		"include":       include,
		"exclude":       exclude,
	})

	// Case 1: Both include and exclude are empty → auth required for ALL.
	if includeEmpty && excludeEmpty {
		masterlog.Debug("AuthMiddleware: Both include/exclude empty, auth required for all", map[string]interface{}{
			"path": path,
		})
		return true
	}

	// Case 2: Include is set → auth required ONLY for matching paths. This
	// branch always returns, which is what gives include precedence over
	// exclude when both are configured.
	if !includeEmpty {
		for _, pattern := range include {
			if m.matchPattern(path, pattern) {
				masterlog.Debug("AuthMiddleware: Path matches include pattern", map[string]interface{}{
					"path":    path,
					"pattern": pattern,
				})
				return true
			}
		}
		masterlog.Debug("AuthMiddleware: Path does not match any include pattern", map[string]interface{}{
			"path":     path,
			"patterns": include,
		})
		return false
	}

	// Case 3: Only exclude is set → auth required EXCEPT for matching paths.
	for _, pattern := range exclude {
		if m.matchPattern(path, pattern) {
			masterlog.Debug("AuthMiddleware: Path matches exclude pattern", map[string]interface{}{
				"path":    path,
				"pattern": pattern,
			})
			return false
		}
	}

	masterlog.Debug("AuthMiddleware: Path does not match any exclude pattern, auth required", map[string]interface{}{
		"path":     path,
		"patterns": exclude,
	})
	return true
}

// matchPattern checks if a path matches a wildcard pattern.
//
// Supported patterns:
//   - "*" matches any path
//   - "/api/*" matches "/api/" and any subpath
//   - any other pattern must equal the path exactly
//
// The pattern matching is prefix-based. If the pattern ends with "*", it
// matches any path that starts with the pattern (excluding the "*"). Note
// that "/api/*" therefore does not match "/api" without the trailing slash.
func (m *AuthMiddleware) matchPattern(path, pattern string) bool {
	// Wildcard: matches everything.
	if pattern == "*" {
		masterlog.Debug("AuthMiddleware: Wildcard pattern matches", map[string]interface{}{
			"path": path,
		})
		return true
	}

	// Pattern ends with wildcard: prefix matching.
	if strings.HasSuffix(pattern, "*") {
		prefix := strings.TrimSuffix(pattern, "*")
		matches := strings.HasPrefix(path, prefix)
		masterlog.Debug("AuthMiddleware: Prefix pattern matching", map[string]interface{}{
			"path":    path,
			"pattern": pattern,
			"prefix":  prefix,
			"matches": matches,
		})
		return matches
	}

	// Exact match.
	matches := path == pattern
	masterlog.Debug("AuthMiddleware: Exact pattern matching", map[string]interface{}{
		"path":    path,
		"pattern": pattern,
		"matches": matches,
	})
	return matches
}

// verifyArgon2Hash verifies a password against an encoded argon2id hash of
// the form
//
//	$argon2id$v=19$m=65536,t=3,p=4$salt$encodedHash
//
// where salt and encodedHash are unpadded standard base64. It returns false
// for any malformed, unsupported or non-matching hash and never panics on bad
// input: parameters that argon2.IDKey would reject are filtered out first.
func (m *AuthMiddleware) verifyArgon2Hash(password, hash string) bool {
	// Largest parallelism value that fits the uint8 accepted by argon2.IDKey.
	const maxParallelism = 1<<8 - 1

	parts := strings.Split(hash, "$")
	if len(parts) != 6 {
		masterlog.Debug("AuthMiddleware: Invalid hash format", map[string]interface{}{
			"parts_count": len(parts),
		})
		return false
	}

	if parts[1] != "argon2id" {
		masterlog.Debug("AuthMiddleware: Unsupported hash type", map[string]interface{}{
			"type": parts[1],
		})
		return false
	}

	// Parse parameters.
	var version int
	var memory, iterations, parallelism uint32
	if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
		return false
	}
	if version != argon2.Version {
		masterlog.Debug("AuthMiddleware: Unsupported argon2 version", map[string]interface{}{
			"version": version,
		})
		return false
	}
	if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, &parallelism); err != nil {
		return false
	}
	// argon2.IDKey panics when time or threads is below 1, and a parallelism
	// above 255 would be silently truncated by the uint8 conversion below.
	if iterations < 1 || parallelism < 1 || parallelism > maxParallelism {
		masterlog.Debug("AuthMiddleware: Invalid argon2 parameters", map[string]interface{}{
			"time":        iterations,
			"parallelism": parallelism,
		})
		return false
	}

	// Decode salt.
	salt, err := base64.RawStdEncoding.DecodeString(parts[4])
	if err != nil {
		masterlog.Debug("AuthMiddleware: Failed to decode salt", map[string]interface{}{
			"error": err,
		})
		return false
	}

	// Decode stored hash.
	decodedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
	if err != nil {
		masterlog.Debug("AuthMiddleware: Failed to decode hash", map[string]interface{}{
			"error": err,
		})
		return false
	}
	// An empty digest can never represent a valid key and must not be passed
	// on as a zero key length.
	if len(decodedHash) == 0 {
		masterlog.Debug("AuthMiddleware: Empty hash", map[string]interface{}{
			"parts_count": len(parts),
		})
		return false
	}

	// Generate hash for comparison.
	hashLength := uint32(len(decodedHash))
	comparisonHash := argon2.IDKey([]byte(password), salt, iterations, memory, uint8(parallelism), hashLength)

	// Use constant time comparison to prevent timing attacks.
	return subtle.ConstantTimeCompare(comparisonHash, decodedHash) == 1
}