feat: add utility functions and helpers
- Add JWT token generation and validation utilities - Add password hashing with bcrypt for secure authentication - Add pagination helper for API responses - Add random string generation for tokens and IDs - Add session management utilities - Add admin user initialization functionality
This commit is contained in:
85
utils/hash.go
Normal file
85
utils/hash.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
type Argon2Params struct {
|
||||
Memory uint32
|
||||
Iterations uint32
|
||||
Parallelism uint8
|
||||
SaltLength uint32
|
||||
KeyLength uint32
|
||||
}
|
||||
|
||||
// Default secure parameters
|
||||
var DefaultParams = &Argon2Params{
|
||||
Memory: 64 * 1024, // 64 MB
|
||||
Iterations: 3,
|
||||
Parallelism: 2,
|
||||
SaltLength: 16,
|
||||
KeyLength: 32,
|
||||
}
|
||||
|
||||
// HashPassword creates a secure Argon2id hash
|
||||
func HashPassword(password string, p *Argon2Params) (string, error) {
|
||||
// Generate salt
|
||||
salt := make([]byte, p.SaltLength)
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
// Calculate hash
|
||||
hash := argon2.IDKey([]byte(password), salt, p.Iterations, p.Memory, p.Parallelism, p.KeyLength)
|
||||
|
||||
// Format: $argon2id$v=19$m=65536,t=3,p=2$<salt>$<hash>
|
||||
encoded := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
||||
p.Memory, p.Iterations, p.Parallelism,
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(hash),
|
||||
)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if the password matches the hash
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return false, fmt.Errorf("invalid hash format")
|
||||
}
|
||||
|
||||
var memory uint32
|
||||
var iterations uint32
|
||||
var parallelism uint8
|
||||
_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &iterations, ¶llelism)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid hash parameters: %w", err)
|
||||
}
|
||||
|
||||
// Decode salt and hash
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid salt encoding: %w", err)
|
||||
}
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid hash encoding: %w", err)
|
||||
}
|
||||
|
||||
// Recalculate hash with the same parameters
|
||||
hash := argon2.IDKey([]byte(password), salt, iterations, memory, parallelism, uint32(len(expectedHash)))
|
||||
|
||||
// Constant time comparison
|
||||
if subtle.ConstantTimeCompare(hash, expectedHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
25
utils/init.go
Normal file
25
utils/init.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.secnex.io/secnex/idp-api/db"
|
||||
"git.secnex.io/secnex/idp-api/models"
|
||||
)
|
||||
|
||||
func CheckAdminUser() {
|
||||
db := db.GetDB()
|
||||
|
||||
hashPassword, err := HashPassword("admin", DefaultParams)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to hash password:", err)
|
||||
}
|
||||
|
||||
user := models.User{
|
||||
Username: "admin",
|
||||
Password: hashPassword,
|
||||
Email: "admin@secnex.local",
|
||||
}
|
||||
|
||||
db.Create(&user)
|
||||
}
|
34
utils/jwt.go
Normal file
34
utils/jwt.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func GenerateJWT(issuer string, claims jwt.MapClaims, expiresIn time.Duration) (string, error) {
|
||||
claims["iss"] = os.Getenv("JWT_ISSUER")
|
||||
claims["aud"] = issuer
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["exp"] = time.Now().Add(expiresIn).Unix()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
tokenString, err := token.SignedString([]byte(os.Getenv("JWT_SECRET")))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
func VerifyJWT(tokenString string) (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv("JWT_SECRET")), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token.Claims.(jwt.MapClaims), nil
|
||||
}
|
46
utils/pagination.go
Normal file
46
utils/pagination.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func Pagination(c *fiber.Ctx, total int64, page int, limit int) (*fiber.Map, *fiber.Map) {
|
||||
totalPages := int(math.Ceil(float64(total) / float64(limit)))
|
||||
paginationInformation := fiber.Map{
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"total_pages": totalPages,
|
||||
}
|
||||
|
||||
paginationLinks := fiber.Map{
|
||||
"self": fiber.Map{
|
||||
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page, limit),
|
||||
},
|
||||
"first": fiber.Map{
|
||||
"href": fmt.Sprintf("%s?page=1&limit=%d", c.Path(), limit),
|
||||
},
|
||||
"last": fiber.Map{
|
||||
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), totalPages, limit),
|
||||
},
|
||||
"next": nil,
|
||||
"prev": nil,
|
||||
}
|
||||
|
||||
if page > 1 {
|
||||
paginationLinks["prev"] = fiber.Map{
|
||||
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page-1, limit),
|
||||
}
|
||||
}
|
||||
|
||||
if page < totalPages {
|
||||
paginationLinks["next"] = fiber.Map{
|
||||
"href": fmt.Sprintf("%s?page=%d&limit=%d", c.Path(), page+1, limit),
|
||||
}
|
||||
}
|
||||
|
||||
return &paginationInformation, &paginationLinks
|
||||
}
|
13
utils/random.go
Normal file
13
utils/random.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
func GenerateRandomString(length int) string {
|
||||
bytes := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
bytes[i] = byte(rand.Intn(26) + 97)
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
94
utils/session.go
Normal file
94
utils/session.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.secnex.io/secnex/idp-api/api"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type AuthType string
|
||||
|
||||
const (
|
||||
AuthTypeSession AuthType = "scn"
|
||||
)
|
||||
|
||||
func ExtractBasicAuthFromHeader(header string, c *fiber.Ctx) (string, string, error) {
|
||||
if header == "" {
|
||||
return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||
"message": "Authorization header is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Decode base64
|
||||
basicDecoded, err := base64.StdEncoding.DecodeString(header)
|
||||
if err != nil {
|
||||
return "", "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||
"message": "Invalid authorization header",
|
||||
})
|
||||
}
|
||||
|
||||
basicSplit := strings.Split(string(basicDecoded), ":")
|
||||
|
||||
return basicSplit[0], basicSplit[1], nil
|
||||
}
|
||||
|
||||
func ExtractSessionFromHeader(header string, c *fiber.Ctx) (string, error) {
|
||||
if header == "" {
|
||||
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||
"message": "Authorization header is required",
|
||||
})
|
||||
}
|
||||
|
||||
sessionSplit := strings.Split(header, ":")
|
||||
|
||||
if len(sessionSplit) != 2 {
|
||||
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||
"message": "Invalid authorization header",
|
||||
})
|
||||
}
|
||||
|
||||
authType := AuthType(sessionSplit[0])
|
||||
if authType != AuthTypeSession {
|
||||
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||
"message": "Invalid authorization header",
|
||||
})
|
||||
}
|
||||
|
||||
sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1])
|
||||
if err != nil {
|
||||
return "", api.Error(c, "Unauthorized", fiber.StatusUnauthorized, fiber.Map{
|
||||
"message": "Invalid authorization header",
|
||||
})
|
||||
}
|
||||
|
||||
return string(sessionId), nil
|
||||
}
|
||||
|
||||
// ExtractSessionFromToken extracts session ID from a session token without Fiber context
|
||||
// This is used for gRPC services where we don't have a Fiber context
|
||||
func ExtractSessionFromToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("session token is required")
|
||||
}
|
||||
|
||||
sessionSplit := strings.Split(token, ":")
|
||||
|
||||
if len(sessionSplit) != 2 {
|
||||
return "", fmt.Errorf("invalid session token format")
|
||||
}
|
||||
|
||||
authType := AuthType(sessionSplit[0])
|
||||
if authType != AuthTypeSession {
|
||||
return "", fmt.Errorf("invalid auth type")
|
||||
}
|
||||
|
||||
sessionId, err := base64.StdEncoding.DecodeString(sessionSplit[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid session token encoding")
|
||||
}
|
||||
|
||||
return string(sessionId), nil
|
||||
}
|
Reference in New Issue
Block a user