package middlewares import ( "encoding/base64" "strings" "time" "git.secnex.io/secnex/idp-api/api" "git.secnex.io/secnex/idp-api/db" "git.secnex.io/secnex/idp-api/repositories" "git.secnex.io/secnex/idp-api/utils" "github.com/gofiber/fiber/v2" ) func BasicAuthMiddleware() fiber.Handler { return func(c *fiber.Ctx) error { header := c.Get("Authorization") basic := strings.Split(header, " ")[1] username, password, err := utils.ExtractBasicAuthFromHeader(basic, c) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } if username == "" || password == "" { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } userRepo := repositories.NewUserRepository(db.GetDB()) user, err := userRepo.GetUserByUsername(username) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } passwordMatch, err := utils.VerifyPassword(password, user.Password) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } if !passwordMatch { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } c.Locals("user", user) return c.Next() } } func AuthMiddleware() fiber.Handler { return func(c *fiber.Ctx) error { header := c.Get("Authorization") bearer := strings.Split(header, " ")[1] sessionId, err := utils.ExtractSessionFromHeader(bearer, c) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } if sessionId == "" { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } sessionRepo := repositories.NewSessionRepository(db.GetDB()) session, err := sessionRepo.GetSessionByID(string(sessionId), true) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } c.Locals("session", session) return c.Next() } } func ApiKeyMiddleware() fiber.Handler { return func(c *fiber.Ctx) error { header := c.Get("Authorization") if header == "" { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Authorization header is required", }) } apiKeySplit := strings.Split(header, " ") if len(apiKeySplit) != 2 { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } apiKeySplitPlatform := strings.Split(apiKeySplit[1], ":") if len(apiKeySplitPlatform) != 2 { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } apiKeyToken, err := base64.StdEncoding.DecodeString(apiKeySplitPlatform[1]) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } apiKeyTokenSplit := strings.Split(string(apiKeyToken), ":") if len(apiKeyTokenSplit) != 2 { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } apiKeyID := apiKeyTokenSplit[0] apiKeyKey := apiKeyTokenSplit[1] apiKeyRepo := repositories.NewApiKeyRepository(db.GetDB()) apiKey, err := apiKeyRepo.GetApiKeyByID(apiKeyID) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } passwordMatch, err := utils.VerifyPassword(apiKeyKey, apiKey.Key) if err != nil { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } if !passwordMatch { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } if apiKey.Revoked { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ "message": "Invalid authorization header", }) } if apiKey.ExpiresAt != nil && apiKey.ExpiresAt.Before(time.Now()) { return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) } c.Locals("user", apiKey.UserID) return c.Next() } }