feat(sql): SQL Injection
This commit is contained in:
@@ -7,10 +7,58 @@ import (
|
||||
"git.secnex.io/secnex/pgson/utils"
|
||||
)
|
||||
|
||||
func DeleteSQL(s *schema.Table, where string) (string, error) {
|
||||
return fmt.Sprintf("UPDATE %s SET deleted_at = CURRENT_TIMESTAMP WHERE %s = %s", utils.SQLQuoteIdent(s.Name), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
|
||||
func DeleteSQL(s *schema.Table, where any) (string, error) {
|
||||
if s == nil {
|
||||
return "", fmt.Errorf("nil table provided")
|
||||
}
|
||||
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
|
||||
return "", fmt.Errorf("invalid table name: %q", s.Name)
|
||||
}
|
||||
if s.PrimaryKey == "" || !utils.IsValidIdentifier(s.PrimaryKey) {
|
||||
return "", fmt.Errorf("invalid primary key: %q", s.PrimaryKey)
|
||||
}
|
||||
found := false
|
||||
for i := range s.Schema {
|
||||
if s.Schema[i].Name == s.PrimaryKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return "", fmt.Errorf("primary key %q not found in schema", s.PrimaryKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("UPDATE %s SET deleted_at = CURRENT_TIMESTAMP WHERE %s = %s",
|
||||
utils.SQLQuoteIdent(s.Name),
|
||||
utils.SQLQuoteIdent(s.PrimaryKey),
|
||||
utils.SQLQuoteValue(where),
|
||||
), nil
|
||||
}
|
||||
|
||||
func HardDeleteSQL(s *schema.Table, where string) (string, error) {
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE %s = %s", utils.SQLQuoteIdent(s.Name), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
|
||||
func HardDeleteSQL(s *schema.Table, where any) (string, error) {
|
||||
if s == nil {
|
||||
return "", fmt.Errorf("nil table provided")
|
||||
}
|
||||
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
|
||||
return "", fmt.Errorf("invalid table name: %q", s.Name)
|
||||
}
|
||||
if s.PrimaryKey == "" || !utils.IsValidIdentifier(s.PrimaryKey) {
|
||||
return "", fmt.Errorf("invalid primary key: %q", s.PrimaryKey)
|
||||
}
|
||||
found := false
|
||||
for i := range s.Schema {
|
||||
if s.Schema[i].Name == s.PrimaryKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return "", fmt.Errorf("primary key %q not found in schema", s.PrimaryKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("DELETE FROM %s WHERE %s = %s",
|
||||
utils.SQLQuoteIdent(s.Name),
|
||||
utils.SQLQuoteIdent(s.PrimaryKey),
|
||||
utils.SQLQuoteValue(where),
|
||||
), nil
|
||||
}
|
||||
|
||||
@@ -8,5 +8,11 @@ import (
|
||||
)
|
||||
|
||||
func DropSQL(s *schema.Table) (string, error) {
|
||||
if s == nil {
|
||||
return "", fmt.Errorf("nil table provided")
|
||||
}
|
||||
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
|
||||
return "", fmt.Errorf("invalid table name: %q", s.Name)
|
||||
}
|
||||
return fmt.Sprintf("DROP TABLE IF EXISTS %s", utils.SQLQuoteIdent(s.Name)), nil
|
||||
}
|
||||
|
||||
@@ -9,19 +9,27 @@ import (
|
||||
)
|
||||
|
||||
func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (string, error) {
|
||||
// Keep unquoted column names for data access
|
||||
if s == nil || len(data) == 0 {
|
||||
return "", fmt.Errorf("invalid input: no table or data provided")
|
||||
}
|
||||
|
||||
if !utils.IsValidIdentifier(s.Name) {
|
||||
return "", fmt.Errorf("invalid table name: %q", s.Name)
|
||||
}
|
||||
|
||||
columnNames := make([]string, 0, len(data[0]))
|
||||
for column := range data[0] {
|
||||
if !utils.IsValidIdentifier(column) {
|
||||
return "", fmt.Errorf("invalid column name: %q", column)
|
||||
}
|
||||
columnNames = append(columnNames, column)
|
||||
}
|
||||
|
||||
// Create quoted column names for SQL
|
||||
columns := make([]string, len(columnNames))
|
||||
for i, col := range columnNames {
|
||||
columns[i] = utils.SQLQuoteIdent(col)
|
||||
}
|
||||
|
||||
// Create a map for quick field lookup
|
||||
fieldMap := make(map[string]*schema.Field)
|
||||
for i := range s.Schema {
|
||||
fieldMap[s.Schema[i].Name] = &s.Schema[i]
|
||||
@@ -37,7 +45,7 @@ func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (stri
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
hashed, err := utils.Hash(valueStr, *field.Algorithm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("hashing error for column %q: %w", colName, err)
|
||||
}
|
||||
value = hashed
|
||||
}
|
||||
@@ -49,9 +57,12 @@ func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (stri
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", utils.SQLQuoteIdent(s.Name), strings.Join(columns, ", "), strings.Join(values, ", "))
|
||||
if returning {
|
||||
// RETURNING the primary key
|
||||
if !utils.IsValidIdentifier(s.PrimaryKey) {
|
||||
return "", fmt.Errorf("invalid primary key column: %q", s.PrimaryKey)
|
||||
}
|
||||
query += " RETURNING " + utils.SQLQuoteIdent(s.PrimaryKey)
|
||||
}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,12 +8,19 @@ import (
|
||||
)
|
||||
|
||||
func TruncateSQL(s *schema.Table, cascade bool, restartIdentity bool) (string, error) {
|
||||
query := fmt.Sprintf("TRUNCATE TABLE %s", utils.SQLQuoteIdent(s.Name))
|
||||
if cascade {
|
||||
query += " CASCADE"
|
||||
if s == nil {
|
||||
return "", fmt.Errorf("nil table provided")
|
||||
}
|
||||
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
|
||||
return "", fmt.Errorf("invalid table name: %q", s.Name)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("TRUNCATE TABLE %s", utils.SQLQuoteIdent(s.Name))
|
||||
if restartIdentity {
|
||||
query += " RESTART IDENTITY"
|
||||
}
|
||||
if cascade {
|
||||
query += " CASCADE"
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,16 @@ import (
|
||||
)
|
||||
|
||||
func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, error) {
|
||||
// Create a map for quick field lookup
|
||||
if s == nil {
|
||||
return "", fmt.Errorf("nil table provided")
|
||||
}
|
||||
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
|
||||
return "", fmt.Errorf("invalid table name: %q", s.Name)
|
||||
}
|
||||
if s.PrimaryKey == "" || !utils.IsValidIdentifier(s.PrimaryKey) {
|
||||
return "", fmt.Errorf("invalid primary key: %q", s.PrimaryKey)
|
||||
}
|
||||
|
||||
fieldMap := make(map[string]*schema.Field)
|
||||
for i := range s.Schema {
|
||||
fieldMap[s.Schema[i].Name] = &s.Schema[i]
|
||||
@@ -17,7 +26,10 @@ func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, erro
|
||||
|
||||
setClause := make([]string, 0, len(data))
|
||||
for field, value := range data {
|
||||
// Check if field is of type "hash" and needs hashing
|
||||
if !utils.IsValidIdentifier(field) {
|
||||
return "", fmt.Errorf("invalid field name: %q", field)
|
||||
}
|
||||
|
||||
if schemaField, exists := fieldMap[field]; exists && schemaField.Type == "hash" && schemaField.Algorithm != nil {
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
hashed, err := utils.Hash(valueStr, *schemaField.Algorithm)
|
||||
@@ -26,8 +38,19 @@ func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, erro
|
||||
}
|
||||
value = hashed
|
||||
}
|
||||
|
||||
setClause = append(setClause, fmt.Sprintf("%s = %s", utils.SQLQuoteIdent(field), utils.SQLQuoteValue(value)))
|
||||
}
|
||||
|
||||
setClause = append(setClause, "updated_at = CURRENT_TIMESTAMP")
|
||||
return fmt.Sprintf("UPDATE %s SET %s WHERE %s = %s", utils.SQLQuoteIdent(s.Name), strings.Join(setClause, ", "), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"UPDATE %s SET %s WHERE %s = %s",
|
||||
utils.SQLQuoteIdent(s.Name),
|
||||
strings.Join(setClause, ", "),
|
||||
utils.SQLQuoteIdent(s.PrimaryKey),
|
||||
utils.SQLQuoteValue(where),
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ type Reference struct {
|
||||
OnUpdate string `json:"on_update"`
|
||||
}
|
||||
|
||||
// Mapping of field types to SQL types
|
||||
var fieldTypeToSQLType = map[string]string{
|
||||
"string": "VARCHAR",
|
||||
"int": "INTEGER",
|
||||
@@ -60,37 +59,68 @@ func (t *Table) JSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
func (f *Field) SQL() string {
|
||||
quotedName := utils.SQLQuoteIdent(f.Name)
|
||||
if !utils.IsValidIdentifier(f.Name) {
|
||||
return ""
|
||||
}
|
||||
|
||||
sqlType := fieldTypeToSQLType[f.Type]
|
||||
if sqlType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
quotedName := utils.SQLQuoteIdent(f.Name)
|
||||
sql := fmt.Sprintf("%s %s", quotedName, sqlType)
|
||||
|
||||
if f.Nullable != nil && !*f.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
|
||||
if f.Primary != nil && *f.Primary {
|
||||
sql += " PRIMARY KEY"
|
||||
if f.Default == nil && f.Type == "uuid" {
|
||||
sql += " DEFAULT uuid_generate_v4()"
|
||||
}
|
||||
}
|
||||
|
||||
if f.Unique != nil && *f.Unique {
|
||||
sql += " UNIQUE"
|
||||
}
|
||||
|
||||
if f.Default != nil && f.Primary == nil {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", utils.SQLQuoteValue(*f.Default))
|
||||
def := *f.Default
|
||||
if utils.IsValidDefault(def) {
|
||||
switch {
|
||||
case utils.IsValidDefault(def):
|
||||
if utils.IsValidDefault(def) && (def == "CURRENT_TIMESTAMP" || def == "now()" || def == "uuid_generate_v4()") {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", def)
|
||||
} else {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", utils.SQLQuoteValue(def))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if f.References != nil {
|
||||
sql += fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(f.References.Table), utils.SQLQuoteIdent(f.References.Column))
|
||||
if f.References.OnDelete != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", f.References.OnDelete)
|
||||
ref := f.References
|
||||
if !utils.IsValidIdentifier(ref.Table) || !utils.IsValidIdentifier(ref.Column) {
|
||||
return ""
|
||||
}
|
||||
if f.References.OnUpdate != "" {
|
||||
sql += fmt.Sprintf(" ON UPDATE %s", f.References.OnUpdate)
|
||||
sql += fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(ref.Table), utils.SQLQuoteIdent(ref.Column))
|
||||
if ref.OnDelete != "" {
|
||||
action, err := utils.SanitizeOnAction(ref.OnDelete)
|
||||
if err == nil && action != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", action)
|
||||
}
|
||||
}
|
||||
if ref.OnUpdate != "" {
|
||||
action, err := utils.SanitizeOnAction(ref.OnUpdate)
|
||||
if err == nil && action != "" {
|
||||
sql += fmt.Sprintf(" ON UPDATE %s", action)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
@@ -98,5 +128,11 @@ func (f *Field) SQLReferences() string {
|
||||
if f.References == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(f.References.Table), utils.SQLQuoteIdent(f.References.Column))
|
||||
ref := f.References
|
||||
|
||||
if !utils.IsValidIdentifier(ref.Table) || !utils.IsValidIdentifier(ref.Column) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(ref.Table), utils.SQLQuoteIdent(ref.Column))
|
||||
}
|
||||
|
||||
53
utils/sql.go
53
utils/sql.go
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -12,7 +13,6 @@ func SQLQuoteIdent(id string) string {
|
||||
func SQLQuoteValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Escape single quotes by doubling them (PostgreSQL standard)
|
||||
escaped := strings.ReplaceAll(v, "'", "''")
|
||||
return fmt.Sprintf("'%s'", escaped)
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
@@ -24,8 +24,57 @@ func SQLQuoteValue(value any) string {
|
||||
case nil:
|
||||
return "NULL"
|
||||
}
|
||||
// For unknown types, convert to string and escape
|
||||
str := fmt.Sprintf("%v", value)
|
||||
escaped := strings.ReplaceAll(str, "'", "''")
|
||||
return fmt.Sprintf("'%s'", escaped)
|
||||
}
|
||||
|
||||
var identifierRe = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]{0,62}$`)
|
||||
|
||||
func IsValidIdentifier(id string) bool {
|
||||
return identifierRe.MatchString(id)
|
||||
}
|
||||
|
||||
var allowedOnActions = map[string]bool{
|
||||
"CASCADE": true,
|
||||
"SET NULL": true,
|
||||
"NO ACTION": true,
|
||||
"RESTRICT": true,
|
||||
"SET DEFAULT": true,
|
||||
}
|
||||
|
||||
func SanitizeOnAction(s string) (string, error) {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return "", nil
|
||||
}
|
||||
u := strings.ToUpper(strings.Join(strings.Fields(s), " "))
|
||||
if allowedOnActions[u] {
|
||||
return u, nil
|
||||
}
|
||||
return "", fmt.Errorf("invalid action: %q", s)
|
||||
}
|
||||
|
||||
var allowedDefaultFuncs = map[string]bool{
|
||||
"CURRENT_TIMESTAMP": true,
|
||||
"NOW()": true,
|
||||
"UUID_GENERATE_V4()": true,
|
||||
}
|
||||
|
||||
func IsValidDefault(val string) bool {
|
||||
v := strings.TrimSpace(val)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`^[+-]?\d+(\.\d+)?$`).MatchString(v) {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
|
||||
return true
|
||||
}
|
||||
if allowedDefaultFuncs[strings.ToUpper(v)] {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user