diff --git a/middlewares/auth.go b/middlewares/auth.go new file mode 100644 index 0000000..954fc6f --- /dev/null +++ b/middlewares/auth.go @@ -0,0 +1,139 @@ +package middlewares + +import ( + "encoding/base64" + "strings" + "time" + + "git.secnex.io/secnex/idp-api/api" + "git.secnex.io/secnex/idp-api/db" + "git.secnex.io/secnex/idp-api/repositories" + "git.secnex.io/secnex/idp-api/utils" + "github.com/gofiber/fiber/v2" +) + +func BasicAuthMiddleware() fiber.Handler { + return func(c *fiber.Ctx) error { + header := c.Get("Authorization") + basic := strings.Split(header, " ")[1] + username, password, err := utils.ExtractBasicAuthFromHeader(basic, c) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + + if username == "" || password == "" { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + userRepo := repositories.NewUserRepository(db.GetDB()) + user, err := userRepo.GetUserByUsername(username) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + passwordMatch, err := utils.VerifyPassword(password, user.Password) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + if !passwordMatch { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + + c.Locals("user", user) + + return c.Next() + } +} + +func AuthMiddleware() fiber.Handler { + return func(c *fiber.Ctx) error { + header := c.Get("Authorization") + bearer := strings.Split(header, " ")[1] + sessionId, err := utils.ExtractSessionFromHeader(bearer, c) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + + if sessionId == "" { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + + sessionRepo := repositories.NewSessionRepository(db.GetDB()) + + session, err := sessionRepo.GetSessionByID(string(sessionId), true) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + c.Locals("session", session) + return c.Next() + } +} + +func ApiKeyMiddleware() fiber.Handler { + return func(c *fiber.Ctx) error { + header := c.Get("Authorization") + if header == "" { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Authorization header is required", + }) + } + + apiKeySplit := strings.Split(header, " ") + if len(apiKeySplit) != 2 { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + apiKeySplitPlatform := strings.Split(apiKeySplit[1], ":") + if len(apiKeySplitPlatform) != 2 { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + apiKeyToken, err := base64.StdEncoding.DecodeString(apiKeySplitPlatform[1]) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + apiKeyTokenSplit := strings.Split(string(apiKeyToken), ":") + if len(apiKeyTokenSplit) != 2 { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + apiKeyID := apiKeyTokenSplit[0] + apiKeyKey := apiKeyTokenSplit[1] + apiKeyRepo := repositories.NewApiKeyRepository(db.GetDB()) + apiKey, err := apiKeyRepo.GetApiKeyByID(apiKeyID) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + passwordMatch, err := utils.VerifyPassword(apiKeyKey, apiKey.Key) + if err != nil { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + if !passwordMatch { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + if apiKey.Revoked { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + if apiKey.ExpiresAt != nil && apiKey.ExpiresAt.Before(time.Now()) { + return api.Error(c, "Unauthorized", fiber.StatusUnauthorized, nil) + } + + c.Locals("user", apiKey.UserID) + return c.Next() + } +} diff --git a/middlewares/database.go b/middlewares/database.go new file mode 100644 index 0000000..72a4697 --- /dev/null +++ b/middlewares/database.go @@ -0,0 +1,29 @@ +package middlewares + +import ( + "git.secnex.io/secnex/idp-api/db" + "github.com/gofiber/fiber/v2" + "gorm.io/gorm" +) + +// DatabaseMiddleware injects the database connection into the fiber context +func DatabaseMiddleware() fiber.Handler { + return func(c *fiber.Ctx) error { + // Get the global database instance + database := db.GetDB() + + // Store the database connection in the context + c.Locals("db", database) + + return c.Next() + } +} + +// GetDBFromContext retrieves the database connection from the fiber context +func GetDBFromContext(c *fiber.Ctx) *gorm.DB { + if db, ok := c.Locals("db").(*gorm.DB); ok { + return db + } + // Fallback to global instance if not found in context + return db.GetDB() +}