204 lines
4.1 KiB
Go
204 lines
4.1 KiB
Go
package sql
|
|
|
|
import (
	"fmt"
	"regexp"
	"sort"
	"strings"
	"time"

	"git.secnex.io/secnex/pgson/schema"
)
|
|
|
|
var (
	// IdentifierRegex matches the identifiers accepted by ValidateIdent: an
	// ASCII letter or underscore followed by any number of ASCII letters,
	// digits, or underscores.
	IdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

	// TypeMap translates schema type names (the keys) into the PostgreSQL
	// column types emitted for them (the values). Note that "hash" columns
	// are stored as VARCHAR.
	TypeMap = map[string]string{
		"bool":     "BOOLEAN",
		"date":     "DATE",
		"datetime": "TIMESTAMP",
		"float":    "FLOAT",
		"hash":     "VARCHAR",
		"int":      "INTEGER",
		"json":     "JSONB",
		"string":   "VARCHAR",
		"time":     "TIME",
		"uuid":     "UUID",
	}

	// FkActions is the set of foreign-key referential actions accepted by
	// ValidateAction. Keys are upper case; ValidateAction upper-cases its
	// input before looking it up.
	FkActions = map[string]struct{}{
		"CASCADE":     {},
		"NO ACTION":   {},
		"RESTRICT":    {},
		"SET DEFAULT": {},
		"SET NULL":    {},
	}

	// DefaultFunctions lists the function-call expressions ValidateDefault
	// accepts as column defaults, in addition to numeric and quoted literals.
	DefaultFunctions = map[string]struct{}{
		"CURRENT_TIMESTAMP":  {},
		"now()":              {},
		"uuid_generate_v4()": {},
	}

	// AlgorithmMap is the set of hashing algorithm names recognized by
	// ValidateAlgorithm. Keys are lower case.
	AlgorithmMap = map[string]struct{}{
		"argon2": {},
		"bcrypt": {},
		"md5":    {},
		"sha256": {},
		"sha512": {},
	}
)
|
|
|
|
func ValidateIdent(name string) error {
|
|
if !IdentifierRegex.MatchString(name) {
|
|
return fmt.Errorf("invalid identifier: %s", name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateType(name string) (string, error) {
|
|
pg, ok := TypeMap[name]
|
|
if !ok {
|
|
return "", fmt.Errorf("invalid type: %s", name)
|
|
}
|
|
return pg, nil
|
|
}
|
|
|
|
func ValidateAction(act string) (string, error) {
|
|
a := strings.ToUpper(act)
|
|
if _, ok := FkActions[a]; !ok {
|
|
return "", fmt.Errorf("invalid action: %s", act)
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
func ValidateDefault(def string) (string, error) {
|
|
if regexp.MustCompile(`^\d+(\.\d+)?$`).MatchString(def) || regexp.MustCompile(`^'.*'$`).MatchString(def) {
|
|
return def, nil
|
|
}
|
|
if _, ok := DefaultFunctions[strings.ToLower(def)]; ok {
|
|
return def, nil
|
|
}
|
|
return "", fmt.Errorf("invalid default: %s", def)
|
|
}
|
|
|
|
func ValidateReferences(refs *schema.Reference) error {
|
|
if err := ValidateIdent(refs.Table); err != nil {
|
|
return err
|
|
}
|
|
if err := ValidateIdent(refs.Column); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateAlgorithm(algo string) error {
|
|
a := strings.ToUpper(algo)
|
|
if _, ok := AlgorithmMap[a]; !ok {
|
|
return fmt.Errorf("invalid algorithm: %s", algo)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ValidatePrimary requires the primary flag to be set: it returns nil for
// true and a "primary is required" error for false.
func ValidatePrimary(primary bool) error {
	if !primary {
		return fmt.Errorf("primary is required")
	}
	return nil
}
|
|
|
|
// ValidateUnique requires the unique flag to be set: it returns nil for true
// and a "unique is required" error for false.
func ValidateUnique(unique bool) error {
	if !unique {
		return fmt.Errorf("unique is required")
	}
	return nil
}
|
|
|
|
func ValidateComment(comment *string) error {
|
|
if comment == nil {
|
|
return nil
|
|
}
|
|
return ValidateIdent(*comment)
|
|
}
|
|
|
|
// QuoteIdent wraps name in double quotes for use as a quoted SQL identifier.
// Embedded double quotes are doubled, which is the SQL escape for a quote
// inside a quoted identifier.
func QuoteIdent(name string) string {
	const dq = `"`
	return dq + strings.Replace(name, dq, dq+dq, -1) + dq
}
|
|
|
|
// QuoteLiteral renders s as a PostgreSQL string literal.
//
// Escaping rules:
//   - single quotes are doubled ('');
//   - NUL characters are dropped (PostgreSQL text values cannot hold them);
//   - tab, newline and carriage return are kept verbatim;
//   - a backslash becomes \\ and any other ASCII control character becomes a
//     three-digit octal escape such as \001.
//
// Backslash escapes are only interpreted inside an escape-string constant.
// So whenever one is emitted, the literal is prefixed with E (E'...').
// Without the prefix, PostgreSQL's default standard_conforming_strings=on
// would store the escapes as-is (a single backslash would come back doubled).
// Strings that need no backslash escapes are returned as plain '...'.
func QuoteLiteral(s string) string {
	var body strings.Builder
	body.Grow(len(s))
	needsE := false

	for _, r := range s {
		switch {
		case r == 0:
			continue
		case r == '\'':
			body.WriteString("''")
		case r == '\\':
			body.WriteString(`\\`)
			needsE = true
		case r == '\t', r == '\n', r == '\r':
			body.WriteRune(r)
		case r < 0x20:
			fmt.Fprintf(&body, `\%03o`, r)
			needsE = true
		default:
			body.WriteRune(r)
		}
	}

	if needsE {
		return "E'" + body.String() + "'"
	}
	return "'" + body.String() + "'"
}
|
|
|
|
func QuoteValue(v any) string {
|
|
if v == nil {
|
|
return "NULL"
|
|
}
|
|
|
|
switch val := v.(type) {
|
|
case string:
|
|
if matched, _ := regexp.MatchString(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, strings.ToLower(val)); matched {
|
|
return QuoteLiteral(val)
|
|
}
|
|
return QuoteLiteral(val)
|
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
|
return fmt.Sprintf("%v", val)
|
|
case float32, float64:
|
|
return fmt.Sprintf("%v", val)
|
|
case bool:
|
|
if val {
|
|
return "TRUE"
|
|
}
|
|
return "FALSE"
|
|
default:
|
|
return QuoteLiteral(fmt.Sprintf("%v", val))
|
|
}
|
|
}
|
|
|
|
func BuildWhereClause(conditions map[string]any) (string, error) {
|
|
if len(conditions) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
var clauses []string
|
|
keys := make([]string, 0, len(conditions))
|
|
for k := range conditions {
|
|
keys = append(keys, k)
|
|
}
|
|
sort.Strings(keys)
|
|
|
|
for _, col := range keys {
|
|
if err := ValidateIdent(col); err != nil {
|
|
return "", fmt.Errorf("invalid column identifier in WHERE clause: %s", col)
|
|
}
|
|
value := conditions[col]
|
|
if value == nil {
|
|
clauses = append(clauses, fmt.Sprintf("%s IS NULL", QuoteIdent(col)))
|
|
} else {
|
|
clauses = append(clauses, fmt.Sprintf("%s = %s", QuoteIdent(col), QuoteValue(value)))
|
|
}
|
|
}
|
|
|
|
return " WHERE " + strings.Join(clauses, " AND "), nil
|
|
}
|