feat(sql): SQL Injection

This commit is contained in:
Björn Benouarets
2025-11-06 11:16:01 +01:00
parent 1f5f07e624
commit 10110071eb
7 changed files with 206 additions and 26 deletions

View File

@@ -7,10 +7,58 @@ import (
"git.secnex.io/secnex/pgson/utils"
)
func DeleteSQL(s *schema.Table, where string) (string, error) {
return fmt.Sprintf("UPDATE %s SET deleted_at = CURRENT_TIMESTAMP WHERE %s = %s", utils.SQLQuoteIdent(s.Name), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
func DeleteSQL(s *schema.Table, where any) (string, error) {
if s == nil {
return "", fmt.Errorf("nil table provided")
}
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
return "", fmt.Errorf("invalid table name: %q", s.Name)
}
if s.PrimaryKey == "" || !utils.IsValidIdentifier(s.PrimaryKey) {
return "", fmt.Errorf("invalid primary key: %q", s.PrimaryKey)
}
found := false
for i := range s.Schema {
if s.Schema[i].Name == s.PrimaryKey {
found = true
break
}
}
if !found {
return "", fmt.Errorf("primary key %q not found in schema", s.PrimaryKey)
}
return fmt.Sprintf("UPDATE %s SET deleted_at = CURRENT_TIMESTAMP WHERE %s = %s",
utils.SQLQuoteIdent(s.Name),
utils.SQLQuoteIdent(s.PrimaryKey),
utils.SQLQuoteValue(where),
), nil
}
func HardDeleteSQL(s *schema.Table, where string) (string, error) {
return fmt.Sprintf("DELETE FROM %s WHERE %s = %s", utils.SQLQuoteIdent(s.Name), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
func HardDeleteSQL(s *schema.Table, where any) (string, error) {
if s == nil {
return "", fmt.Errorf("nil table provided")
}
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
return "", fmt.Errorf("invalid table name: %q", s.Name)
}
if s.PrimaryKey == "" || !utils.IsValidIdentifier(s.PrimaryKey) {
return "", fmt.Errorf("invalid primary key: %q", s.PrimaryKey)
}
found := false
for i := range s.Schema {
if s.Schema[i].Name == s.PrimaryKey {
found = true
break
}
}
if !found {
return "", fmt.Errorf("primary key %q not found in schema", s.PrimaryKey)
}
return fmt.Sprintf("DELETE FROM %s WHERE %s = %s",
utils.SQLQuoteIdent(s.Name),
utils.SQLQuoteIdent(s.PrimaryKey),
utils.SQLQuoteValue(where),
), nil
}

View File

@@ -8,5 +8,11 @@ import (
)
func DropSQL(s *schema.Table) (string, error) {
if s == nil {
return "", fmt.Errorf("nil table provided")
}
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
return "", fmt.Errorf("invalid table name: %q", s.Name)
}
return fmt.Sprintf("DROP TABLE IF EXISTS %s", utils.SQLQuoteIdent(s.Name)), nil
}

View File

@@ -9,19 +9,27 @@ import (
)
func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (string, error) {
// Keep unquoted column names for data access
if s == nil || len(data) == 0 {
return "", fmt.Errorf("invalid input: no table or data provided")
}
if !utils.IsValidIdentifier(s.Name) {
return "", fmt.Errorf("invalid table name: %q", s.Name)
}
columnNames := make([]string, 0, len(data[0]))
for column := range data[0] {
if !utils.IsValidIdentifier(column) {
return "", fmt.Errorf("invalid column name: %q", column)
}
columnNames = append(columnNames, column)
}
// Create quoted column names for SQL
columns := make([]string, len(columnNames))
for i, col := range columnNames {
columns[i] = utils.SQLQuoteIdent(col)
}
// Create a map for quick field lookup
fieldMap := make(map[string]*schema.Field)
for i := range s.Schema {
fieldMap[s.Schema[i].Name] = &s.Schema[i]
@@ -37,7 +45,7 @@ func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (stri
valueStr := fmt.Sprintf("%v", value)
hashed, err := utils.Hash(valueStr, *field.Algorithm)
if err != nil {
return "", err
return "", fmt.Errorf("hashing error for column %q: %w", colName, err)
}
value = hashed
}
@@ -49,9 +57,12 @@ func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (stri
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", utils.SQLQuoteIdent(s.Name), strings.Join(columns, ", "), strings.Join(values, ", "))
if returning {
// RETURNING the primary key
if !utils.IsValidIdentifier(s.PrimaryKey) {
return "", fmt.Errorf("invalid primary key column: %q", s.PrimaryKey)
}
query += " RETURNING " + utils.SQLQuoteIdent(s.PrimaryKey)
}
return query, nil
}

View File

@@ -8,12 +8,19 @@ import (
)
func TruncateSQL(s *schema.Table, cascade bool, restartIdentity bool) (string, error) {
query := fmt.Sprintf("TRUNCATE TABLE %s", utils.SQLQuoteIdent(s.Name))
if cascade {
query += " CASCADE"
if s == nil {
return "", fmt.Errorf("nil table provided")
}
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
return "", fmt.Errorf("invalid table name: %q", s.Name)
}
query := fmt.Sprintf("TRUNCATE TABLE %s", utils.SQLQuoteIdent(s.Name))
if restartIdentity {
query += " RESTART IDENTITY"
}
if cascade {
query += " CASCADE"
}
return query, nil
}

View File

@@ -9,7 +9,16 @@ import (
)
func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, error) {
// Create a map for quick field lookup
if s == nil {
return "", fmt.Errorf("nil table provided")
}
if s.Name == "" || !utils.IsValidIdentifier(s.Name) {
return "", fmt.Errorf("invalid table name: %q", s.Name)
}
if s.PrimaryKey == "" || !utils.IsValidIdentifier(s.PrimaryKey) {
return "", fmt.Errorf("invalid primary key: %q", s.PrimaryKey)
}
fieldMap := make(map[string]*schema.Field)
for i := range s.Schema {
fieldMap[s.Schema[i].Name] = &s.Schema[i]
@@ -17,7 +26,10 @@ func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, erro
setClause := make([]string, 0, len(data))
for field, value := range data {
// Check if field is of type "hash" and needs hashing
if !utils.IsValidIdentifier(field) {
return "", fmt.Errorf("invalid field name: %q", field)
}
if schemaField, exists := fieldMap[field]; exists && schemaField.Type == "hash" && schemaField.Algorithm != nil {
valueStr := fmt.Sprintf("%v", value)
hashed, err := utils.Hash(valueStr, *schemaField.Algorithm)
@@ -26,8 +38,19 @@ func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, erro
}
value = hashed
}
setClause = append(setClause, fmt.Sprintf("%s = %s", utils.SQLQuoteIdent(field), utils.SQLQuoteValue(value)))
}
setClause = append(setClause, "updated_at = CURRENT_TIMESTAMP")
return fmt.Sprintf("UPDATE %s SET %s WHERE %s = %s", utils.SQLQuoteIdent(s.Name), strings.Join(setClause, ", "), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
query := fmt.Sprintf(
"UPDATE %s SET %s WHERE %s = %s",
utils.SQLQuoteIdent(s.Name),
strings.Join(setClause, ", "),
utils.SQLQuoteIdent(s.PrimaryKey),
utils.SQLQuoteValue(where),
)
return query, nil
}