package sql

import (
	"fmt"
	"regexp"
	"sort"
	"strings"
	"time"

	"git.secnex.io/secnex/pgson/schema"
)

var (
	// IdentifierRegex matches the identifiers accepted by ValidateIdent: an
	// ASCII letter or underscore followed by ASCII letters, digits or
	// underscores.
	IdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

	// TypeMap maps the type names accepted by ValidateType to the PostgreSQL
	// type it returns for each of them.
	TypeMap = map[string]string{
		"string":   "VARCHAR",
		"int":      "INTEGER",
		"float":    "FLOAT",
		"bool":     "BOOLEAN",
		"date":     "DATE",
		"time":     "TIME",
		"datetime": "TIMESTAMP",
		"uuid":     "UUID",
		"json":     "JSONB",
		"hash":     "VARCHAR",
	}

	// FkActions is the set of referential actions accepted by ValidateAction.
	// Keys are upper case because ValidateAction upper-cases its input.
	FkActions = map[string]struct{}{
		"CASCADE":     {},
		"RESTRICT":    {},
		"SET NULL":    {},
		"NO ACTION":   {},
		"SET DEFAULT": {},
	}

	// DefaultFunctions is the set of function expressions ValidateDefault
	// accepts verbatim as a default. ValidateDefault compares against these
	// keys case-insensitively.
	DefaultFunctions = map[string]struct{}{
		"CURRENT_TIMESTAMP":  {},
		"now()":              {},
		"uuid_generate_v4()": {},
	}

	// AlgorithmMap is the set of algorithm names accepted by
	// ValidateAlgorithm. ValidateAlgorithm compares against these keys
	// case-insensitively.
	AlgorithmMap = map[string]struct{}{
		"argon2": {},
		"bcrypt": {},
		"md5":    {},
		"sha256": {},
		"sha512": {},
	}

	// defaultNumberRegex matches an optionally negative integer or decimal
	// number accepted by ValidateDefault.
	defaultNumberRegex = regexp.MustCompile(`^-?\d+(\.\d+)?$`)

	// defaultStringRegex matches exactly one single-quoted SQL string literal
	// whose embedded quotes are doubled (''). A value such as
	// 'x'; DROP TABLE t; --' does not match, so it cannot smuggle extra SQL
	// into the generated statement.
	defaultStringRegex = regexp.MustCompile(`^'(?:[^'\n]|'')*'$`)
)

// ValidateIdent returns an error unless name matches IdentifierRegex.
func ValidateIdent(name string) error {
	if !IdentifierRegex.MatchString(name) {
		return fmt.Errorf("invalid identifier: %s", name)
	}
	return nil
}

// ValidateType returns the PostgreSQL type that TypeMap assigns to name, or an
// error if name is not a key of TypeMap. The lookup is case-sensitive.
func ValidateType(name string) (string, error) {
	pg, ok := TypeMap[name]
	if !ok {
		return "", fmt.Errorf("invalid type: %s", name)
	}
	return pg, nil
}

// ValidateAction upper-cases act and returns it if it is listed in FkActions.
// Otherwise it returns an error that quotes the original input.
func ValidateAction(act string) (string, error) {
	a := strings.ToUpper(act)
	if _, ok := FkActions[a]; !ok {
		return "", fmt.Errorf("invalid action: %s", act)
	}
	return a, nil
}

// ValidateDefault checks that def is safe to splice into SQL as a default
// expression and returns it unchanged. It accepts three forms:
//
//   - an optionally negative integer or decimal number, such as 42, -1 or 3.14;
//   - one single-quoted string literal with embedded quotes doubled, such as
//     'it''s';
//   - one of the DefaultFunctions, compared case-insensitively.
//
// Anything else yields an error.
func ValidateDefault(def string) (string, error) {
	if defaultNumberRegex.MatchString(def) || defaultStringRegex.MatchString(def) {
		return def, nil
	}
	// DefaultFunctions mixes upper-case (CURRENT_TIMESTAMP) and lower-case
	// keys, so a single case-normalized map lookup cannot match all of them.
	for fn := range DefaultFunctions {
		if strings.EqualFold(fn, def) {
			return def, nil
		}
	}
	return "", fmt.Errorf("invalid default: %s", def)
}

// ValidateReferences checks that refs.Table and refs.Column are both valid
// identifiers according to ValidateIdent. A nil refs is reported as an error
// instead of causing a nil-pointer panic.
func ValidateReferences(refs *schema.Reference) error {
	if refs == nil {
		return fmt.Errorf("invalid references: nil")
	}
	if err := ValidateIdent(refs.Table); err != nil {
		return err
	}
	if err := ValidateIdent(refs.Column); err != nil {
		return err
	}
	return nil
}

// ValidateAlgorithm returns an error unless algo names one of the entries in
// AlgorithmMap. The comparison is case-insensitive, so "BCRYPT" and "bcrypt"
// are both accepted.
func ValidateAlgorithm(algo string) error {
	for name := range AlgorithmMap {
		if strings.EqualFold(name, algo) {
			return nil
		}
	}
	return fmt.Errorf("invalid algorithm: %s", algo)
}

// ValidatePrimary returns nil when primary is true and an error otherwise.
func ValidatePrimary(primary bool) error {
	if primary {
		return nil
	}
	return fmt.Errorf("primary is required")
}

// ValidateUnique returns nil when unique is true and an error otherwise.
func ValidateUnique(unique bool) error {
	if unique {
		return nil
	}
	return fmt.Errorf("unique is required")
}

// ValidateComment accepts a nil comment. A non-nil comment must pass
// ValidateIdent, so it may contain only ASCII letters, digits and underscores
// and must not start with a digit.
//
// NOTE(review): this rejects ordinary prose such as "user email". If free text
// is intended, callers could quote it with QuoteLiteral instead. Confirm how
// callers embed the comment before relaxing this check.
func ValidateComment(comment *string) error {
	if comment == nil {
		return nil
	}
	return ValidateIdent(*comment)
}

// QuoteIdent returns name as a double-quoted SQL identifier, doubling any
// embedded double quotes.
func QuoteIdent(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// QuoteLiteral returns s as a PostgreSQL string literal.
//
// Single quotes are doubled and NUL characters are dropped. Tab, LF and CR are
// kept as-is. Backslashes are doubled, and other control characters below
// 0x20 are written as \ooo octal escapes. Whenever such a backslash escape is
// emitted, the literal uses the escape-string form E'...'. The escapes then
// mean the same thing regardless of the server's standard_conforming_strings
// setting; a plain '...' literal would store them verbatim.
func QuoteLiteral(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 3)
	escaped := false
	for _, r := range s {
		switch {
		case r == '\x00':
			// Dropped: a NUL character cannot appear in SQL text.
		case r == '\'':
			b.WriteString("''")
		case r == '\\':
			b.WriteString(`\\`)
			escaped = true
		case r == '\t', r == '\n', r == '\r':
			b.WriteRune(r)
		case r < 0x20:
			fmt.Fprintf(&b, `\%03o`, r)
			escaped = true
		default:
			b.WriteRune(r)
		}
	}
	if escaped {
		return "E'" + b.String() + "'"
	}
	return "'" + b.String() + "'"
}

// QuoteValue renders v as an SQL literal:
//
//   - nil becomes NULL;
//   - strings are quoted with QuoteLiteral;
//   - integers and finite floats are written unquoted using %v;
//   - NaN, +Inf and -Inf become 'NaN', 'Infinity' and '-Infinity';
//   - booleans become TRUE or FALSE;
//   - time.Time values are quoted in RFC 3339 form with nanoseconds;
//   - anything else is formatted with %v and quoted with QuoteLiteral.
func QuoteValue(v any) string {
	switch val := v.(type) {
	case nil:
		return "NULL"
	case string:
		return QuoteLiteral(val)
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%v", val)
	case float32, float64:
		s := fmt.Sprintf("%v", val)
		// fmt spells non-finite floats NaN, +Inf and -Inf. Emitted unquoted,
		// SQL would read them as an identifier or an expression, not a number.
		switch s {
		case "NaN":
			return "'NaN'"
		case "+Inf":
			return "'Infinity'"
		case "-Inf":
			return "'-Infinity'"
		}
		return s
	case bool:
		if val {
			return "TRUE"
		}
		return "FALSE"
	case time.Time:
		// %v on a time.Time can append a monotonic-clock suffix ("m=+0.01")
		// that is not valid timestamp input. RFC 3339 is unambiguous.
		return QuoteLiteral(val.Format(time.RFC3339Nano))
	default:
		return QuoteLiteral(fmt.Sprintf("%v", val))
	}
}

// BuildWhereClause builds a WHERE clause that ANDs together one test per entry
// in conditions. A nil value produces `"col" IS NULL`; any other value
// produces `"col" = <QuoteValue(value)>`. Columns are emitted in sorted order.
// An empty map yields "". Otherwise the result starts with " WHERE ". An error
// is returned if any column name fails ValidateIdent.
func BuildWhereClause(conditions map[string]any) (string, error) {
	if len(conditions) == 0 {
		return "", nil
	}

	// Sort the columns so the same conditions always produce the same SQL,
	// despite Go's randomized map iteration order.
	cols := make([]string, 0, len(conditions))
	for col := range conditions {
		cols = append(cols, col)
	}
	sort.Strings(cols)

	clauses := make([]string, 0, len(cols))
	for _, col := range cols {
		if err := ValidateIdent(col); err != nil {
			return "", fmt.Errorf("invalid column identifier in WHERE clause: %s", col)
		}
		// "col = NULL" is never true in SQL, so nil needs IS NULL.
		if value := conditions[col]; value == nil {
			clauses = append(clauses, QuoteIdent(col)+" IS NULL")
		} else {
			clauses = append(clauses, QuoteIdent(col)+" = "+QuoteValue(value))
		}
	}
	return " WHERE " + strings.Join(clauses, " AND "), nil
}