diff --git a/utils/hash.go b/utils/hash.go new file mode 100644 index 0000000..a748e6e --- /dev/null +++ b/utils/hash.go @@ -0,0 +1,85 @@ +package utils + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +type Argon2Params struct { + Memory uint32 + Iterations uint32 + Parallelism uint8 + SaltLength uint32 + KeyLength uint32 +} + +// Default secure parameters +var DefaultParams = &Argon2Params{ + Memory: 64 * 1024, // 64 MB + Iterations: 3, + Parallelism: 2, + SaltLength: 16, + KeyLength: 32, +} + +// HashPassword creates a secure Argon2id hash +func HashPassword(password string, p *Argon2Params) (string, error) { + // Generate salt + salt := make([]byte, p.SaltLength) + _, err := rand.Read(salt) + if err != nil { + return "", fmt.Errorf("failed to generate salt: %w", err) + } + + // Calculate hash + hash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength) + + // Format: $argon2id$v=19$m=65536,t=3,p=2$$ + encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", + p.Memory, p.Iterations, p.Parallelism, + base64.RawStdEncoding.EncodeToString(salt), + base64.RawStdEncoding.EncodeToString(hash), + ) + + return encoded, nil +} + +// VerifyPassword checks if the password matches the hash +func VerifyPassword(password, encodedHash string) (bool, error) { + parts := strings.Split(encodedHash, "$") + if len(parts) != 6 { + return false, fmt.Errorf("invalid hash format") + } + + var memory uint32 + var iterations uint32 + var parallelism uint8 + _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism) + if err != nil { + return false, fmt.Errorf("invalid hash parameters: %w", err) + } + + // Decode salt and hash + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return false, fmt.Errorf("invalid salt encoding: %w", err) + } + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return false, fmt.Errorf("invalid hash encoding: %w", err) + } + + // Recalculate hash with the same parameters + hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash))) + + // Constant time comparison + if subtle.ConstantTimeCompare(hash, expectedHash) == 1 { + return true, nil + } + return false, nil +} diff --git a/utils/init.go b/utils/init.go new file mode 100644 index 0000000..1c413d3 --- /dev/null +++ b/utils/init.go @@ -0,0 +1,25 @@ +package utils + +import ( + "fmt" + + "git.secnex.io/secnex/idp-api/db" + "git.secnex.io/secnex/idp-api/models" +) + +func CheckAdminUser() { + db := db.GetDB() + + hashPassword, err := HashPassword("admin", DefaultParams) + if err != nil { + fmt.Println("Failed to hash password:", err) + } + + user := models.User{ + Username: "admin", + Password: hashPassword, + Email: "admin@secnex.local", + } + + db.Create(&user) +} diff --git a/utils/jwt.go b/utils/jwt.go new file mode 100644 index 0000000..6b66efa --- /dev/null +++ b/utils/jwt.go @@ -0,0 +1,34 @@ +package utils + +import ( + "os" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func GenerateJWT(issuer string, claims jwt.MapClaims, expiresIn time.Duration) (string, error) { + claims["iss"] = os.Getenv("JWT_ISSUER") + claims["aud"] = issuer + claims["iat"] = time.Now().Unix() + claims["exp"] = time.Now().Add(expiresIn).Unix() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + tokenString, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func VerifyJWT(tokenString string) (jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte(os.Getenv("JWT_SECRET")), nil + }) + if err != nil { + return nil, err + } + + return token.Claims.(jwt.MapClaims), nil +} diff --git a/utils/pagination.go b/utils/pagination.go new file mode 100644 index 0000000..fddc6ac --- /dev/null +++ b/utils/pagination.go @@ -0,0 +1,46 @@ +package utils + +import ( + "fmt" + "math" + + "github.com/gofiber/fiber/v2" +) + +func Pagination(c *fiber.Ctx, total int64, page int, limit int) (*fiber.Map, *fiber.Map) { + totalPages := int(math.Ceil(float64(total) / float64(limit))) + paginationInformation := fiber.Map{ + "page": page, + "limit": limit, + "total": total, + "total_pages": totalPages, + } + + paginationLinks := fiber.Map{ + "self": fiber.Map{ + "href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page, limit), + }, + "first": fiber.Map{ + "href": fmt.Sprintf("%s?page=1&limit=%d", c.Path(), limit), + }, + "last": fiber.Map{ + "href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), totalPages, limit), + }, + "next": nil, + "prev": nil, + } + + if page > 1 { + paginationLinks["prev"] = fiber.Map{ + "href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page-1, limit), + } + } + + if page < totalPages { + paginationLinks["next"] = fiber.Map{ + "href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page+1, limit), + } + } + + return &paginationInformation, &paginationLinks +} diff --git a/utils/random.go b/utils/random.go new file mode 100644 index 0000000..eb6b885 --- /dev/null +++ b/utils/random.go @@ -0,0 +1,13 @@ +package utils + +import ( + "math/rand" +) + +func GenerateRandomString(length int) string { + bytes := make([]byte, length) + for i := 0; i < length; i++ { + bytes[i] = byte(rand.Intn(26) + 97) + } + return string(bytes) +} diff --git a/utils/session.go b/utils/session.go new file mode 100644 index 0000000..ed5b135 --- /dev/null +++ b/utils/session.go @@ -0,0 +1,94 @@ +package utils + +import ( + "encoding/base64" + "fmt" + "strings" + + "git.secnex.io/secnex/idp-api/api" + "github.com/gofiber/fiber/v2" +) + +type AuthType string + +const ( + AuthTypeSession AuthType = "scn" +) + +func ExtractBasicAuthFromHeader(header string, c *fiber.Ctx) (string, string, error) { + if header == "" { + return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Authorization header is required", + }) + } + + // Decode base64 + basicDecoded, err := base64.StdEncoding.DecodeString(header) + if err != nil { + return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + basicSplit := strings.Split(string(basicDecoded), ":") + + return basicSplit[0], basicSplit[1], nil +} + +func ExtractSessionFromHeader(header string, c *fiber.Ctx) (string, error) { + if header == "" { + return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Authorization header is required", + }) + } + + sessionSplit := strings.Split(header, ":") + + if len(sessionSplit) != 2 { + return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + authType := AuthType(sessionSplit[0]) + if authType != AuthTypeSession { + return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1]) + if err != nil { + return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{ + "message": "Invalid authorization header", + }) + } + + return string(sessionId), nil +} + +// ExtractSessionFromToken extracts session ID from a session token without Fiber context +// This is used for gRPC services where we don't have a Fiber context +func ExtractSessionFromToken(token string) (string, error) { + if token == "" { + return "", fmt.Errorf("session token is required") + } + + sessionSplit := strings.Split(token, ":") + + if len(sessionSplit) != 2 { + return "", fmt.Errorf("invalid session token format") + } + + authType := AuthType(sessionSplit[0]) + if authType != AuthTypeSession { + return "", fmt.Errorf("invalid auth type") + } + + sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1]) + if err != nil { + return "", fmt.Errorf("invalid session token encoding") + } + + return string(sessionId), nil +}