feat: add utility functions and helpers

- Add JWT token generation and validation utilities
- Add password hashing with bcrypt for secure authentication
- Add pagination helper for API responses
- Add random string generation for tokens and IDs
- Add session management utilities
- Add admin user initialization functionality
This commit is contained in:
Björn Benouarets
2025-09-25 23:24:18 +02:00
parent 9a8a93061d
commit f509f6e524
6 changed files with 297 additions and 0 deletions

85
utils/hash.go Normal file
View File

@@ -0,0 +1,85 @@
package utils
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
)
type Argon2Params struct {
Memory uint32
Iterations uint32
Parallelism uint8
SaltLength uint32
KeyLength uint32
}
// Default secure parameters
var DefaultParams = &Argon2Params{
Memory: 64 * 1024, // 64 MB
Iterations: 3,
Parallelism: 2,
SaltLength: 16,
KeyLength: 32,
}
// HashPassword creates a secure Argon2id hash
func HashPassword(password string, p *Argon2Params) (string, error) {
// Generate salt
salt := make([]byte, p.SaltLength)
_, err := rand.Read(salt)
if err != nil {
return "", fmt.Errorf("failed to generate salt: %w", err)
}
// Calculate hash
hash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)
// Format: $argon2id$v=19$m=65536,t=3,p=2$<salt>$<hash>
encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
p.Memory, p.Iterations, p.Parallelism,
base64.RawStdEncoding.EncodeToString(salt),
base64.RawStdEncoding.EncodeToString(hash),
)
return encoded, nil
}
// VerifyPassword checks if the password matches the hash
func VerifyPassword(password, encodedHash string) (bool, error) {
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
var memory uint32
var iterations uint32
var parallelism uint8
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, &parallelism)
if err != nil {
return false, fmt.Errorf("invalid hash parameters: %w", err)
}
// Decode salt and hash
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false, fmt.Errorf("invalid salt encoding: %w", err)
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false, fmt.Errorf("invalid hash encoding: %w", err)
}
// Recalculate hash with the same parameters
hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash)))
// Constant time comparison
if subtle.ConstantTimeCompare(hash, expectedHash) == 1 {
return true, nil
}
return false, nil
}

25
utils/init.go Normal file
View File

@@ -0,0 +1,25 @@
package utils
import (
"fmt"
"git.secnex.io/secnex/idp-api/db"
"git.secnex.io/secnex/idp-api/models"
)
func CheckAdminUser() {
db := db.GetDB()
hashPassword, err := HashPassword("admin", DefaultParams)
if err != nil {
fmt.Println("Failed to hash password:", err)
}
user := models.User{
Username: "admin",
Password: hashPassword,
Email: "admin@secnex.local",
}
db.Create(&user)
}

34
utils/jwt.go Normal file
View File

@@ -0,0 +1,34 @@
package utils
import (
"os"
"time"
"github.com/golang-jwt/jwt/v5"
)
func GenerateJWT(issuer string, claims jwt.MapClaims, expiresIn time.Duration) (string, error) {
claims["iss"] = os.Getenv("JWT_ISSUER")
claims["aud"] = issuer
claims["iat"] = time.Now().Unix()
claims["exp"] = time.Now().Add(expiresIn).Unix()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(os.Getenv("JWT_SECRET")))
if err != nil {
return "", err
}
return tokenString, nil
}
func VerifyJWT(tokenString string) (jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(os.Getenv("JWT_SECRET")), nil
})
if err != nil {
return nil, err
}
return token.Claims.(jwt.MapClaims), nil
}

46
utils/pagination.go Normal file
View File

@@ -0,0 +1,46 @@
package utils
import (
"fmt"
"math"
"github.com/gofiber/fiber/v2"
)
func Pagination(c *fiber.Ctx, total int64, page int, limit int) (*fiber.Map, *fiber.Map) {
totalPages := int(math.Ceil(float64(total) / float64(limit)))
paginationInformation := fiber.Map{
"page": page,
"limit": limit,
"total": total,
"total_pages": totalPages,
}
paginationLinks := fiber.Map{
"self": fiber.Map{
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page, limit),
},
"first": fiber.Map{
"href": fmt.Sprintf("%s?page=1&limit=%d", c.Path(), limit),
},
"last": fiber.Map{
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), totalPages, limit),
},
"next": nil,
"prev": nil,
}
if page > 1 {
paginationLinks["prev"] = fiber.Map{
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page-1, limit),
}
}
if page < totalPages {
paginationLinks["next"] = fiber.Map{
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page+1, limit),
}
}
return &paginationInformation, &paginationLinks
}

13
utils/random.go Normal file
View File

@@ -0,0 +1,13 @@
package utils
import (
"math/rand"
)
func GenerateRandomString(length int) string {
bytes := make([]byte, length)
for i := 0; i < length; i++ {
bytes[i] = byte(rand.Intn(26) + 97)
}
return string(bytes)
}

94
utils/session.go Normal file
View File

@@ -0,0 +1,94 @@
package utils
import (
"encoding/base64"
"fmt"
"strings"
"git.secnex.io/secnex/idp-api/api"
"github.com/gofiber/fiber/v2"
)
type AuthType string
const (
AuthTypeSession AuthType = "scn"
)
func ExtractBasicAuthFromHeader(header string, c *fiber.Ctx) (string, string, error) {
if header == "" {
return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
"message": "Authorization header is required",
})
}
// Decode base64
basicDecoded, err := base64.StdEncoding.DecodeString(header)
if err != nil {
return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
"message": "Invalid authorization header",
})
}
basicSplit := strings.Split(string(basicDecoded), ":")
return basicSplit[0], basicSplit[1], nil
}
func ExtractSessionFromHeader(header string, c *fiber.Ctx) (string, error) {
if header == "" {
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
"message": "Authorization header is required",
})
}
sessionSplit := strings.Split(header, ":")
if len(sessionSplit) != 2 {
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
"message": "Invalid authorization header",
})
}
authType := AuthType(sessionSplit[0])
if authType != AuthTypeSession {
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
"message": "Invalid authorization header",
})
}
sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1])
if err != nil {
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
"message": "Invalid authorization header",
})
}
return string(sessionId), nil
}
// ExtractSessionFromToken extracts session ID from a session token without Fiber context
// This is used for gRPC services where we don't have a Fiber context
func ExtractSessionFromToken(token string) (string, error) {
if token == "" {
return "", fmt.Errorf("session token is required")
}
sessionSplit := strings.Split(token, ":")
if len(sessionSplit) != 2 {
return "", fmt.Errorf("invalid session token format")
}
authType := AuthType(sessionSplit[0])
if authType != AuthTypeSession {
return "", fmt.Errorf("invalid auth type")
}
sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1])
if err != nil {
return "", fmt.Errorf("invalid session token encoding")
}
return string(sessionId), nil
}