Files
pgson/sql/validate.go

204 lines
4.1 KiB
Go

package sql
import (
"fmt"
"regexp"
"sort"
"strings"
"git.secnex.io/secnex/pgson/schema"
)
// IdentifierRegex matches names made of ASCII letters, digits and
// underscores that do not start with a digit.
var IdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// TypeMap translates a schema type name into the SQL column type used for it.
var TypeMap = map[string]string{
	"string":   "VARCHAR",
	"int":      "INTEGER",
	"float":    "FLOAT",
	"bool":     "BOOLEAN",
	"date":     "DATE",
	"time":     "TIME",
	"datetime": "TIMESTAMP",
	"uuid":     "UUID",
	"json":     "JSONB",
	"hash":     "VARCHAR",
}

// FkActions is the set of recognized foreign-key actions, keyed in upper case.
var FkActions = map[string]struct{}{
	"CASCADE":     {},
	"RESTRICT":    {},
	"SET NULL":    {},
	"NO ACTION":   {},
	"SET DEFAULT": {},
}

// DefaultFunctions is the set of expressions recognized as column defaults
// in addition to plain literals. Note that the keys are not uniformly cased.
var DefaultFunctions = map[string]struct{}{
	"CURRENT_TIMESTAMP":  {},
	"now()":              {},
	"uuid_generate_v4()": {},
}

// AlgorithmMap is the set of recognized algorithm names, keyed in lower case.
var AlgorithmMap = map[string]struct{}{
	"argon2": {},
	"bcrypt": {},
	"md5":    {},
	"sha256": {},
	"sha512": {},
}
// ValidateIdent reports whether name is acceptable as an identifier.
// It returns nil when name matches IdentifierRegex and an
// "invalid identifier" error naming the offending value otherwise.
func ValidateIdent(name string) error {
	if IdentifierRegex.MatchString(name) {
		return nil
	}
	return fmt.Errorf("invalid identifier: %s", name)
}
// ValidateType looks name up in TypeMap and returns the mapped SQL type.
// The lookup is exact (case-sensitive); an unknown name yields an empty
// string and an "invalid type" error.
func ValidateType(name string) (string, error) {
	if sqlType, found := TypeMap[name]; found {
		return sqlType, nil
	}
	return "", fmt.Errorf("invalid type: %s", name)
}
// ValidateAction upper-cases act and checks it against FkActions.
// On success the upper-cased form is returned; otherwise the result is
// empty and the error quotes the action exactly as the caller supplied it.
func ValidateAction(act string) (string, error) {
	normalized := strings.ToUpper(act)
	_, known := FkActions[normalized]
	if known {
		return normalized, nil
	}
	return "", fmt.Errorf("invalid action: %s", act)
}
// defaultNumberRegex matches an unsigned integer or decimal literal
// such as 42 or 3.14.
var defaultNumberRegex = regexp.MustCompile(`^\d+(\.\d+)?$`)

// defaultStringRegex matches a single-line, single-quoted string literal in
// which every embedded single quote is doubled (''). Requiring the doubling
// stops input such as 'a'; DROP TABLE t; --' from closing the literal early.
var defaultStringRegex = regexp.MustCompile(`^'(?:[^'\n]|'')*'$`)

// ValidateDefault checks that def is safe to use as a column default.
//
// Accepted forms are:
//   - an unsigned integer or decimal number,
//   - a single-quoted string literal whose inner quotes are doubled,
//   - one of the entries of DefaultFunctions, compared case-insensitively.
//
// On success def is returned unchanged; otherwise the result is empty and an
// "invalid default" error is returned.
//
// Backslashes inside a quoted literal are passed through untouched, which is
// only safe when the server treats them literally
// (standard_conforming_strings = on, the PostgreSQL default).
func ValidateDefault(def string) (string, error) {
	if defaultNumberRegex.MatchString(def) || defaultStringRegex.MatchString(def) {
		return def, nil
	}

	// The keys of DefaultFunctions are not uniformly cased, so both sides are
	// lower-cased before comparing instead of indexing the map directly.
	lowered := strings.ToLower(def)
	for fn := range DefaultFunctions {
		if strings.ToLower(fn) == lowered {
			return def, nil
		}
	}
	return "", fmt.Errorf("invalid default: %s", def)
}
// ValidateReferences checks that both the table and the column named by refs
// pass ValidateIdent. The table is checked first and the first failure is
// returned as-is. refs must not be nil.
func ValidateReferences(refs *schema.Reference) error {
	for _, ident := range []string{refs.Table, refs.Column} {
		if err := ValidateIdent(ident); err != nil {
			return err
		}
	}
	return nil
}
// ValidateAlgorithm reports whether algo names an entry of AlgorithmMap.
// The comparison is case-insensitive: algo is lower-cased before the lookup
// because the map keys are lower-case. An unknown name yields an
// "invalid algorithm" error quoting the value as supplied.
func ValidateAlgorithm(algo string) error {
	if _, ok := AlgorithmMap[strings.ToLower(algo)]; !ok {
		return fmt.Errorf("invalid algorithm: %s", algo)
	}
	return nil
}
// ValidatePrimary returns nil when primary is true and a
// "primary is required" error when it is false.
func ValidatePrimary(primary bool) error {
	if !primary {
		return fmt.Errorf("primary is required")
	}
	return nil
}
// ValidateUnique returns nil when unique is true and a
// "unique is required" error when it is false.
func ValidateUnique(unique bool) error {
	if !unique {
		return fmt.Errorf("unique is required")
	}
	return nil
}
// ValidateComment checks an optional comment. A nil comment is always
// accepted; a non-nil one must pass ValidateIdent and its error is returned
// unchanged.
//
// NOTE(review): because the comment is held to identifier rules, free text
// containing spaces or punctuation is rejected — confirm this is intended.
func ValidateComment(comment *string) error {
	if comment != nil {
		return ValidateIdent(*comment)
	}
	return nil
}
// QuoteIdent wraps name in double quotes for use as a quoted SQL identifier,
// doubling every double quote that appears inside it.
func QuoteIdent(name string) string {
	var sb strings.Builder
	sb.Grow(len(name) + 2)
	sb.WriteByte('"')
	// '"' is ASCII and never occurs inside a multi-byte UTF-8 sequence, so a
	// byte-wise copy is safe.
	for i := 0; i < len(name); i++ {
		if name[i] == '"' {
			sb.WriteByte('"')
		}
		sb.WriteByte(name[i])
	}
	sb.WriteByte('"')
	return sb.String()
}
// QuoteLiteral renders s as a SQL string literal.
//
// NUL characters are dropped, single quotes are doubled, and tab, newline and
// carriage return are copied through unchanged. A backslash is written as \\
// and any other control character below 0x20 as a three-digit octal escape
// (\ooo).
//
// Backslash escapes are only interpreted by PostgreSQL inside escape string
// constants, so whenever one is emitted the literal is written as  E'...'
// (with a leading space so the E cannot merge with a preceding token, as
// libpq's PQescapeLiteral does). Strings that need no escape are returned as
// a plain '...' literal.
func QuoteLiteral(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	escaped := false
	for _, r := range s {
		switch {
		case r == 0:
			// NUL cannot be represented in a literal; drop it.
		case r == '\'':
			b.WriteString("''")
		case r == '\\':
			escaped = true
			b.WriteString(`\\`)
		case r == '\t' || r == '\n' || r == '\r':
			b.WriteRune(r)
		case r < 0x20:
			escaped = true
			fmt.Fprintf(&b, "\\%03o", r)
		default:
			b.WriteRune(r)
		}
	}
	if escaped {
		return " E'" + b.String() + "'"
	}
	return "'" + b.String() + "'"
}
func QuoteValue(v any) string {
if v == nil {
return "NULL"
}
switch val := v.(type) {
case string:
if matched, _ := regexp.MatchString(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, strings.ToLower(val)); matched {
return QuoteLiteral(val)
}
return QuoteLiteral(val)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%v", val)
case float32, float64:
return fmt.Sprintf("%v", val)
case bool:
if val {
return "TRUE"
}
return "FALSE"
default:
return QuoteLiteral(fmt.Sprintf("%v", val))
}
}
// BuildWhereClause turns conditions into a " WHERE ..." fragment whose terms
// are joined with AND.
//
// Column names are processed in sorted order so the output is deterministic.
// Each name must pass ValidateIdent; the first one that does not aborts the
// build with an error and an empty result. A nil value produces
// `"col" IS NULL`; any other value produces `"col" = <QuoteValue(value)>`.
// An empty or nil map yields an empty string and no error.
func BuildWhereClause(conditions map[string]any) (string, error) {
	if len(conditions) == 0 {
		return "", nil
	}

	columns := make([]string, 0, len(conditions))
	for name := range conditions {
		columns = append(columns, name)
	}
	sort.Strings(columns)

	var sb strings.Builder
	sb.WriteString(" WHERE ")
	for i, name := range columns {
		if ValidateIdent(name) != nil {
			return "", fmt.Errorf("invalid column identifier in WHERE clause: %s", name)
		}
		if i > 0 {
			sb.WriteString(" AND ")
		}
		sb.WriteString(QuoteIdent(name))
		if v := conditions[name]; v == nil {
			sb.WriteString(" IS NULL")
		} else {
			sb.WriteString(" = ")
			sb.WriteString(QuoteValue(v))
		}
	}
	return sb.String(), nil
}