feat: add utility functions and helpers
- Add JWT token generation and validation utilities - Add password hashing with bcrypt for secure authentication - Add pagination helper for API responses - Add random string generation for tokens and IDs - Add session management utilities - Add admin user initialization functionality
This commit is contained in:
85
utils/hash.go
Normal file
85
utils/hash.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Argon2Params struct {
|
||||||
|
Memory uint32
|
||||||
|
Iterations uint32
|
||||||
|
Parallelism uint8
|
||||||
|
SaltLength uint32
|
||||||
|
KeyLength uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default secure parameters
|
||||||
|
var DefaultParams = &Argon2Params{
|
||||||
|
Memory: 64 * 1024, // 64 MB
|
||||||
|
Iterations: 3,
|
||||||
|
Parallelism: 2,
|
||||||
|
SaltLength: 16,
|
||||||
|
KeyLength: 32,
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashPassword creates a secure Argon2id hash
|
||||||
|
func HashPassword(password string, p *Argon2Params) (string, error) {
|
||||||
|
// Generate salt
|
||||||
|
salt := make([]byte, p.SaltLength)
|
||||||
|
_, err := rand.Read(salt)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate hash
|
||||||
|
hash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)
|
||||||
|
|
||||||
|
// Format: $argon2id$v=19$m=65536,t=3,p=2$<salt>$<hash>
|
||||||
|
encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
||||||
|
p.Memory, p.Iterations, p.Parallelism,
|
||||||
|
base64.RawStdEncoding.EncodeToString(salt),
|
||||||
|
base64.RawStdEncoding.EncodeToString(hash),
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyPassword checks if the password matches the hash
|
||||||
|
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||||
|
parts := strings.Split(encodedHash, "$")
|
||||||
|
if len(parts) != 6 {
|
||||||
|
return false, fmt.Errorf("invalid hash format")
|
||||||
|
}
|
||||||
|
|
||||||
|
var memory uint32
|
||||||
|
var iterations uint32
|
||||||
|
var parallelism uint8
|
||||||
|
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid hash parameters: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode salt and hash
|
||||||
|
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid salt encoding: %w", err)
|
||||||
|
}
|
||||||
|
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid hash encoding: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recalculate hash with the same parameters
|
||||||
|
hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash)))
|
||||||
|
|
||||||
|
// Constant time comparison
|
||||||
|
if subtle.ConstantTimeCompare(hash, expectedHash) == 1 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
25
utils/init.go
Normal file
25
utils/init.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.secnex.io/secnex/idp-api/db"
|
||||||
|
"git.secnex.io/secnex/idp-api/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CheckAdminUser() {
|
||||||
|
db := db.GetDB()
|
||||||
|
|
||||||
|
hashPassword, err := HashPassword("admin", DefaultParams)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to hash password:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := models.User{
|
||||||
|
Username: "admin",
|
||||||
|
Password: hashPassword,
|
||||||
|
Email: "admin@secnex.local",
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Create(&user)
|
||||||
|
}
|
34
utils/jwt.go
Normal file
34
utils/jwt.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateJWT(issuer string, claims jwt.MapClaims, expiresIn time.Duration) (string, error) {
|
||||||
|
claims["iss"] = os.Getenv("JWT_ISSUER")
|
||||||
|
claims["aud"] = issuer
|
||||||
|
claims["iat"] = time.Now().Unix()
|
||||||
|
claims["exp"] = time.Now().Add(expiresIn).Unix()
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString([]byte(os.Getenv("JWT_SECRET")))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenString, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyJWT(tokenString string) (jwt.MapClaims, error) {
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(os.Getenv("JWT_SECRET")), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token.Claims.(jwt.MapClaims), nil
|
||||||
|
}
|
46
utils/pagination.go
Normal file
46
utils/pagination.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Pagination(c *fiber.Ctx, total int64, page int, limit int) (*fiber.Map, *fiber.Map) {
|
||||||
|
totalPages := int(math.Ceil(float64(total) / float64(limit)))
|
||||||
|
paginationInformation := fiber.Map{
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"total": total,
|
||||||
|
"total_pages": totalPages,
|
||||||
|
}
|
||||||
|
|
||||||
|
paginationLinks := fiber.Map{
|
||||||
|
"self": fiber.Map{
|
||||||
|
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page, limit),
|
||||||
|
},
|
||||||
|
"first": fiber.Map{
|
||||||
|
"href": fmt.Sprintf("%s?page=1&limit=%d", c.Path(), limit),
|
||||||
|
},
|
||||||
|
"last": fiber.Map{
|
||||||
|
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), totalPages, limit),
|
||||||
|
},
|
||||||
|
"next": nil,
|
||||||
|
"prev": nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
if page > 1 {
|
||||||
|
paginationLinks["prev"] = fiber.Map{
|
||||||
|
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page-1, limit),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if page < totalPages {
|
||||||
|
paginationLinks["next"] = fiber.Map{
|
||||||
|
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page+1, limit),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &paginationInformation, &paginationLinks
|
||||||
|
}
|
13
utils/random.go
Normal file
13
utils/random.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateRandomString(length int) string {
|
||||||
|
bytes := make([]byte, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
bytes[i] = byte(rand.Intn(26) + 97)
|
||||||
|
}
|
||||||
|
return string(bytes)
|
||||||
|
}
|
94
utils/session.go
Normal file
94
utils/session.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.secnex.io/secnex/idp-api/api"
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthTypeSession AuthType = "scn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExtractBasicAuthFromHeader(header string, c *fiber.Ctx) (string, string, error) {
|
||||||
|
if header == "" {
|
||||||
|
return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||||
|
"message": "Authorization header is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode base64
|
||||||
|
basicDecoded, err := base64.StdEncoding.DecodeString(header)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||||
|
"message": "Invalid authorization header",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
basicSplit := strings.Split(string(basicDecoded), ":")
|
||||||
|
|
||||||
|
return basicSplit[0], basicSplit[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractSessionFromHeader(header string, c *fiber.Ctx) (string, error) {
|
||||||
|
if header == "" {
|
||||||
|
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||||
|
"message": "Authorization header is required",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionSplit := strings.Split(header, ":")
|
||||||
|
|
||||||
|
if len(sessionSplit) != 2 {
|
||||||
|
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||||
|
"message": "Invalid authorization header",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := AuthType(sessionSplit[0])
|
||||||
|
if authType != AuthTypeSession {
|
||||||
|
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||||
|
"message": "Invalid authorization header",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||||
|
"message": "Invalid authorization header",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(sessionId), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractSessionFromToken extracts session ID from a session token without Fiber context
|
||||||
|
// This is used for gRPC services where we don't have a Fiber context
|
||||||
|
func ExtractSessionFromToken(token string) (string, error) {
|
||||||
|
if token == "" {
|
||||||
|
return "", fmt.Errorf("session token is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionSplit := strings.Split(token, ":")
|
||||||
|
|
||||||
|
if len(sessionSplit) != 2 {
|
||||||
|
return "", fmt.Errorf("invalid session token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := AuthType(sessionSplit[0])
|
||||||
|
if authType != AuthTypeSession {
|
||||||
|
return "", fmt.Errorf("invalid auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid session token encoding")
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(sessionId), nil
|
||||||
|
}
|
Reference in New Issue
Block a user