8 Commits

Author SHA1 Message Date
79ea266ab1 Merge pull request 'feat(sql): Add select query' (#2) from bbenouarets-20251110 into main
Reviewed-on: #2
2025-11-09 23:07:05 +00:00
Björn Benouarets
73341d5bf1 feat(sql): Add select query 2025-11-10 00:05:12 +01:00
02cf76ed6b Merge pull request 'Add common whitespace characters, ASCII, HTML, backslashes to literal escaping' (#1) from advanced-literal-escape into main
Reviewed-on: #1
2025-11-06 18:37:32 +00:00
Björn Benouarets
fc982db3ef feat(sql): Add common whitespace characters, ASCII, HTML, backslashes to literal escaping 2025-11-06 19:34:48 +01:00
Björn Benouarets
ad2eaa6ebb feat(sql): Add new DML (DELETE, UPDATE) 2025-11-06 17:10:24 +01:00
Björn Benouarets
4ed1cd3b88 feat(sql): Add new DML (DELETE, UPDATE) 2025-11-06 17:10:12 +01:00
Björn Benouarets
59d6c911f9 feat(sql): SQL Injection 2025-11-06 16:44:28 +01:00
Björn Benouarets
10110071eb feat(sql): SQL Injection 2025-11-06 11:16:01 +01:00
18 changed files with 642 additions and 273 deletions

172
build/alter.go Normal file
View File

@@ -0,0 +1,172 @@
package build
import (
"fmt"
"strings"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/sql"
)
func AlterTable(s *schema.Table) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
var ddlParts, comments []string
for _, f := range s.Schema {
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
pgType, err := sql.ValidateType(f.Type)
if err != nil {
return nil, err
}
colParts := []string{sql.QuoteIdent(f.Name), pgType}
if f.Nullable != nil && !*f.Nullable {
colParts = append(colParts, "NOT NULL")
}
if f.Unique != nil && *f.Unique {
colParts = append(colParts, "UNIQUE")
}
if f.Default != nil && *f.Default != "" {
def, err := sql.ValidateDefault(*f.Default)
if err != nil {
return nil, err
}
colParts = append(colParts, "DEFAULT "+def)
}
if f.References != nil {
if err := sql.ValidateIdent(f.References.Table); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.References.Column); err != nil {
return nil, err
}
onDelete, err := sql.ValidateAction(f.References.OnDelete)
if err != nil {
return nil, err
}
onUpdate, err := sql.ValidateAction(f.References.OnUpdate)
if err != nil {
return nil, err
}
fk := fmt.Sprintf("REFERENCES %s(%s) ON DELETE %s ON UPDATE %s", sql.QuoteIdent(f.References.Table), sql.QuoteIdent(f.References.Column), onDelete, onUpdate)
colParts = append(colParts, fk)
}
ddlParts = append(ddlParts, strings.Join(colParts, " "))
if f.Comment != nil {
comments = append(comments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name), *f.Comment))
}
}
ddl := fmt.Sprintf(sql.DDL_ALTER_TABLE, sql.QuoteIdent(s.Name), strings.Join(ddlParts, " "))
return &ddl, nil
}
func AddColumn(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_ADD_COLUMN, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name))
return &ddl, nil
}
func DropColumn(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_DROP_COLUMN, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name))
return &ddl, nil
}
func AlterColumn(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
pgType, err := sql.ValidateType(f.Type)
if err != nil {
return nil, err
}
colParts := []string{sql.QuoteIdent(f.Name), pgType}
if f.Nullable != nil && !*f.Nullable {
colParts = append(colParts, "NOT NULL")
}
ddl := fmt.Sprintf(sql.DDL_ALTER_COLUMN, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name), strings.Join(colParts, " "))
if f.Default != nil && *f.Default != "" {
def, err := sql.ValidateDefault(*f.Default)
if err != nil {
return nil, err
}
colParts = append(colParts, "DEFAULT "+def)
}
return &ddl, nil
}
func AddForeignKey(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_ADD_FOREIGN_KEY, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name))
return &ddl, nil
}
func DropForeignKey(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_DROP_FOREIGN_KEY, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name))
return &ddl, nil
}
func AddConstraint(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_ADD_CONSTRAINT, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name))
return &ddl, nil
}
func DropConstraint(s *schema.Table, f *schema.Field) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if err := sql.ValidateIdent(f.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_DROP_CONSTRAINT, sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name))
return &ddl, nil
}

View File

@@ -2,21 +2,108 @@ package build
import (
"fmt"
"log"
"strings"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/utils"
"git.secnex.io/secnex/pgson/sql"
)
func CreateSQL(s *schema.Table) (string, error) {
schemaParts := make([]string, len(s.Schema))
for i, field := range s.Schema {
schemaParts[i] = field.SQL()
func CreateTable(s *schema.Table) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
schemaParts = append(schemaParts, "\"created_at\" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP")
schemaParts = append(schemaParts, "\"updated_at\" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP")
schemaParts = append(schemaParts, "\"deleted_at\" TIMESTAMP NULL")
query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", utils.SQLQuoteIdent(s.Name), strings.Join(schemaParts, ", "))
return query, nil
var cols, comments, pks []string
for _, f := range s.Schema {
if err := sql.ValidateIdent(f.Name); err != nil {
log.Printf("Invalid identifier: %s", f.Name)
return nil, err
}
pgType, err := sql.ValidateType(f.Type)
if err != nil {
log.Printf("Invalid type: %s", f.Type)
return nil, err
}
if f.Type == "hash" {
if _, ok := sql.AlgorithmMap[strings.ToLower(*f.Algorithm)]; !ok {
log.Printf("Invalid algorithm: %s", *f.Algorithm)
return nil, fmt.Errorf("invalid algorithm: %s", *f.Algorithm)
}
}
colParts := []string{sql.QuoteIdent(f.Name), pgType}
if f.Nullable != nil && !*f.Nullable {
colParts = append(colParts, "NOT NULL")
}
if f.Unique != nil && *f.Unique {
colParts = append(colParts, "UNIQUE")
}
// Auto-generate UUID for UUID primary keys if no default is specified
if f.Primary != nil && *f.Primary && f.Type == "uuid" {
if f.Default == nil || *f.Default == "" {
colParts = append(colParts, "DEFAULT gen_random_uuid()")
} else {
def, err := sql.ValidateDefault(*f.Default)
if err != nil {
return nil, err
}
colParts = append(colParts, "DEFAULT "+def)
}
} else if f.Default != nil && *f.Default != "" {
def, err := sql.ValidateDefault(*f.Default)
if err != nil {
return nil, err
}
colParts = append(colParts, "DEFAULT "+def)
}
if f.References != nil {
if err := sql.ValidateIdent(f.References.Table); err != nil {
log.Printf("Invalid references table: %s", f.References.Table)
return nil, err
}
if err := sql.ValidateIdent(f.References.Column); err != nil {
log.Printf("Invalid references column: %s", f.References.Column)
return nil, err
}
onDelete, err := sql.ValidateAction(f.References.OnDelete)
if err != nil {
log.Printf("Invalid on delete: %s", f.References.OnDelete)
return nil, err
}
onUpdate, err := sql.ValidateAction(f.References.OnUpdate)
if err != nil {
log.Printf("Invalid on update: %s", f.References.OnUpdate)
return nil, err
}
fk := fmt.Sprintf("REFERENCES %s(%s) ON DELETE %s ON UPDATE %s", sql.QuoteIdent(f.References.Table), sql.QuoteIdent(f.References.Column), onDelete, onUpdate)
colParts = append(colParts, fk)
}
cols = append(cols, strings.Join(colParts, " "))
if f.Primary != nil && *f.Primary {
pks = append(pks, sql.QuoteIdent(f.Name))
}
if f.Comment != nil {
comments = append(comments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s';", sql.QuoteIdent(s.Name), sql.QuoteIdent(f.Name), *f.Comment))
}
}
if len(pks) > 0 {
cols = append(cols, "PRIMARY KEY ("+strings.Join(pks, ", ")+")")
}
ddl := fmt.Sprintf(sql.DDL_CREATE_TABLE, sql.QuoteIdent(s.Name), strings.Join(cols, ",\n "))
if len(comments) > 0 {
ddl += "\n" + strings.Join(comments, "\n")
}
return &ddl, nil
}

View File

@@ -4,13 +4,24 @@ import (
"fmt"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/utils"
"git.secnex.io/secnex/pgson/sql"
)
func DeleteSQL(s *schema.Table, where string) (string, error) {
return fmt.Sprintf("UPDATE %s SET deleted_at = CURRENT_TIMESTAMP WHERE %s = %s", utils.SQLQuoteIdent(s.Name), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
}
func Delete(s *schema.Table, where map[string]any) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
func HardDeleteSQL(s *schema.Table, where string) (string, error) {
return fmt.Sprintf("DELETE FROM %s WHERE %s = %s", utils.SQLQuoteIdent(s.Name), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
if len(where) == 0 {
return nil, fmt.Errorf("WHERE clause is required for DELETE to prevent accidental deletion of all rows")
}
whereClause, err := sql.BuildWhereClause(where)
if err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DML_DELETE, sql.QuoteIdent(s.Name), whereClause)
return &ddl, nil
}

View File

@@ -4,9 +4,14 @@ import (
"fmt"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/utils"
"git.secnex.io/secnex/pgson/sql"
)
func DropSQL(s *schema.Table) (string, error) {
return fmt.Sprintf("DROP TABLE IF EXISTS %s", utils.SQLQuoteIdent(s.Name)), nil
func DropTable(s *schema.Table) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DDL_DROP_TABLE, sql.QuoteIdent(s.Name))
return &ddl, nil
}

View File

@@ -2,59 +2,32 @@ package build
import (
"fmt"
"log"
"sort"
"strings"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/utils"
"git.secnex.io/secnex/pgson/sql"
)
func InsertManySQL(s *schema.Table, data []map[string]any, returning bool) (string, error) {
// Keep unquoted column names for data access
columnNames := make([]string, 0, len(data[0]))
for column := range data[0] {
columnNames = append(columnNames, column)
func Insert(s *schema.Table, data map[string]any) (*string, error) {
cols := []string{}
values := []string{}
// Keys of the data map
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
// Create quoted column names for SQL
columns := make([]string, len(columnNames))
for i, col := range columnNames {
columns[i] = utils.SQLQuoteIdent(col)
}
// Create a map for quick field lookup
fieldMap := make(map[string]*schema.Field)
for i := range s.Schema {
fieldMap[s.Schema[i].Name] = &s.Schema[i]
}
values := make([]string, len(data))
for i, row := range data {
rowValues := make([]string, len(columnNames))
for j, colName := range columnNames {
value := row[colName]
if field, exists := fieldMap[colName]; exists && field.Type == "hash" && field.Algorithm != nil {
valueStr := fmt.Sprintf("%v", value)
hashed, err := utils.Hash(valueStr, *field.Algorithm)
if err != nil {
return "", err
}
value = hashed
}
rowValues[j] = utils.SQLQuoteValue(value)
sort.Strings(keys)
for _, k := range keys {
if err := sql.ValidateIdent(k); err != nil {
log.Printf("Invalid column identifier: %s", k)
return nil, err
}
values[i] = fmt.Sprintf("(%s)", strings.Join(rowValues, ", "))
cols = append(cols, sql.QuoteIdent(k))
values = append(values, sql.QuoteValue(data[k]))
}
ddl := fmt.Sprintf(sql.DML_INSERT_INTO, sql.QuoteIdent(s.Name), strings.Join(cols, ", "), strings.Join(values, ", "))
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", utils.SQLQuoteIdent(s.Name), strings.Join(columns, ", "), strings.Join(values, ", "))
if returning {
// RETURNING the primary key
query += " RETURNING " + utils.SQLQuoteIdent(s.PrimaryKey)
}
return query, nil
}
func InsertSQL(s *schema.Table, data map[string]any, returning bool) (string, error) {
return InsertManySQL(s, []map[string]any{data}, returning)
return &ddl, nil
}

14
build/select.go Normal file
View File

@@ -0,0 +1,14 @@
package build
import (
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/sql"
)
func Select(s *schema.Table, where map[string]any) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
return nil, nil
}

View File

@@ -4,16 +4,14 @@ import (
"fmt"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/utils"
"git.secnex.io/secnex/pgson/sql"
)
func TruncateSQL(s *schema.Table, cascade bool, restartIdentity bool) (string, error) {
query := fmt.Sprintf("TRUNCATE TABLE %s", utils.SQLQuoteIdent(s.Name))
if cascade {
query += " CASCADE"
func TruncateTable(s *schema.Table) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
if restartIdentity {
query += " RESTART IDENTITY"
}
return query, nil
ddl := fmt.Sprintf(sql.DDL_TRUNCATE_TABLE, sql.QuoteIdent(s.Name))
return &ddl, nil
}

View File

@@ -2,32 +2,63 @@ package build
import (
"fmt"
"log"
"sort"
"strings"
"git.secnex.io/secnex/pgson/schema"
"git.secnex.io/secnex/pgson/utils"
"git.secnex.io/secnex/pgson/sql"
)
func UpdateSQL(s *schema.Table, data map[string]any, where string) (string, error) {
// Create a map for quick field lookup
fieldMap := make(map[string]*schema.Field)
for i := range s.Schema {
fieldMap[s.Schema[i].Name] = &s.Schema[i]
func Update(s *schema.Table, data map[string]any, where map[string]any) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
setClause := make([]string, 0, len(data))
for field, value := range data {
// Check if field is of type "hash" and needs hashing
if schemaField, exists := fieldMap[field]; exists && schemaField.Type == "hash" && schemaField.Algorithm != nil {
valueStr := fmt.Sprintf("%v", value)
hashed, err := utils.Hash(valueStr, *schemaField.Algorithm)
if err != nil {
return "", err
}
value = hashed
}
setClause = append(setClause, fmt.Sprintf("%s = %s", utils.SQLQuoteIdent(field), utils.SQLQuoteValue(value)))
if len(data) == 0 {
return nil, fmt.Errorf("no columns to update")
}
setClause = append(setClause, "updated_at = CURRENT_TIMESTAMP")
return fmt.Sprintf("UPDATE %s SET %s WHERE %s = %s", utils.SQLQuoteIdent(s.Name), strings.Join(setClause, ", "), utils.SQLQuoteIdent(s.PrimaryKey), utils.SQLQuoteValue(where)), nil
if len(where) == 0 {
return nil, fmt.Errorf("WHERE clause is required for UPDATE to prevent accidental updates")
}
var setParts []string
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
if err := sql.ValidateIdent(k); err != nil {
log.Printf("Invalid column identifier: %s", k)
return nil, err
}
setParts = append(setParts, fmt.Sprintf("%s = %s", sql.QuoteIdent(k), sql.QuoteValue(data[k])))
}
whereClause, err := sql.BuildWhereClause(where)
if err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DML_UPDATE, sql.QuoteIdent(s.Name), strings.Join(setParts, ", "), whereClause)
return &ddl, nil
}
func UpdateDeletedAt(s *schema.Table, where map[string]any) (*string, error) {
if err := sql.ValidateIdent(s.Name); err != nil {
return nil, err
}
whereClause, err := sql.BuildWhereClause(where)
if err != nil {
return nil, err
}
ddl := fmt.Sprintf(sql.DML_UPDATE, sql.QuoteIdent(s.Name), "deleted_at = CURRENT_TIMESTAMP", whereClause)
return &ddl, nil
}

View File

@@ -19,6 +19,6 @@ func NewSchema(json []byte) (*schema.Table, error) {
return schema.NewTable(json)
}
func ReadJsonFileToMap(path string) ([]map[string]any, error) {
func ReadJsonFileToMap(path string) (map[string]any, error) {
return utils.ReadJsonFileToMap(path)
}

View File

@@ -5,37 +5,34 @@ import (
"git.secnex.io/secnex/pgson/schema"
)
func Create(s *schema.Table) (string, error) {
if err := s.Validate(); err != nil {
return "", err
}
return build.CreateSQL(s)
func CreateTable(schema *schema.Table) (*string, error) {
return build.CreateTable(schema)
}
func Drop(s *schema.Table) (string, error) {
return build.DropSQL(s)
func DropTable(schema *schema.Table) (*string, error) {
return build.DropTable(schema)
}
func InsertMany(s *schema.Table, data []map[string]any, returning bool) (string, error) {
return build.InsertManySQL(s, data, returning)
func TruncateTable(schema *schema.Table) (*string, error) {
return build.TruncateTable(schema)
}
func Insert(s *schema.Table, data map[string]any, returning bool) (string, error) {
return build.InsertSQL(s, data, returning)
func Select(schema *schema.Table, where map[string]any) (*string, error) {
return build.Select(schema, where)
}
func Update(s *schema.Table, data map[string]any, where string) (string, error) {
return build.UpdateSQL(s, data, where)
func Insert(schema *schema.Table, data map[string]any) (*string, error) {
return build.Insert(schema, data)
}
func Delete(s *schema.Table, where string) (string, error) {
return build.DeleteSQL(s, where)
func Update(schema *schema.Table, data map[string]any, where map[string]any) (*string, error) {
return build.Update(schema, data, where)
}
func HardDelete(s *schema.Table, where string) (string, error) {
return build.HardDeleteSQL(s, where)
func Delete(schema *schema.Table, where map[string]any) (*string, error) {
return build.UpdateDeletedAt(schema, where)
}
func Truncate(s *schema.Table, cascade bool, restartIdentity bool) (string, error) {
return build.TruncateSQL(s, cascade, restartIdentity)
func HardDelete(schema *schema.Table, where map[string]any) (*string, error) {
return build.Delete(schema, where)
}

View File

@@ -1,66 +0,0 @@
package schema
import (
"fmt"
"slices"
)
var (
requiredAttributes = [2]string{"name", "type"}
onlySupportedTypes = []string{"string", "int", "float", "bool", "date", "time", "datetime", "uuid", "json", "hash"}
)
func ValidateField(field *Field) error {
if err := ValidateHashAlgorithm(field); err != nil {
return err
}
if err := ValidateAttributes(field); err != nil {
return err
}
if err := ValidateFieldType(field); err != nil {
return err
}
return nil
}
func ValidateAttributes(field *Field) error {
// Check if the field has keys that are necessary for the field
for _, attribute := range requiredAttributes {
if _, ok := map[string]any{field.Name: field.Type}[attribute]; !ok {
return fmt.Errorf("attribute %s is required", attribute)
}
}
return nil
}
func ValidateFieldType(field *Field) error {
if !slices.Contains(onlySupportedTypes, field.Type) {
return fmt.Errorf("unsupported type: %s", field.Type)
}
return nil
}
func ValidateHashAlgorithm(field *Field) error {
if field.Type != "hash" {
return nil
}
if field.Algorithm == nil {
return fmt.Errorf("algorithm is required for hash field")
}
var hashAlgorithms = []string{"argon2", "bcrypt", "md5", "sha256", "sha512"}
for _, algorithm := range hashAlgorithms {
if *field.Algorithm == algorithm {
return nil
}
}
return fmt.Errorf("unsupported hash algorithm: %s", *field.Algorithm)
}
func (t *Table) Validate() error {
for _, field := range t.Schema {
if err := ValidateField(&field); err != nil {
return err
}
}
return nil
}

View File

@@ -2,9 +2,6 @@ package schema
import (
"encoding/json"
"fmt"
"git.secnex.io/secnex/pgson/utils"
)
type Table struct {
@@ -32,20 +29,6 @@ type Reference struct {
OnUpdate string `json:"on_update"`
}
// Mapping of field types to SQL types
var fieldTypeToSQLType = map[string]string{
"string": "VARCHAR",
"int": "INTEGER",
"float": "FLOAT",
"bool": "BOOLEAN",
"date": "DATE",
"time": "TIME",
"datetime": "TIMESTAMP",
"uuid": "UUID",
"json": "JSONB",
"hash": "VARCHAR",
}
func NewTable(data []byte) (*Table, error) {
var table Table
err := json.Unmarshal(data, &table)
@@ -58,45 +41,3 @@ func NewTable(data []byte) (*Table, error) {
func (t *Table) JSON() ([]byte, error) {
return json.Marshal(t)
}
func (f *Field) SQL() string {
quotedName := utils.SQLQuoteIdent(f.Name)
sqlType := fieldTypeToSQLType[f.Type]
if sqlType == "" {
return ""
}
sql := fmt.Sprintf("%s %s", quotedName, sqlType)
if f.Nullable != nil && !*f.Nullable {
sql += " NOT NULL"
}
if f.Primary != nil && *f.Primary {
sql += " PRIMARY KEY"
if f.Default == nil && f.Type == "uuid" {
sql += " DEFAULT uuid_generate_v4()"
}
}
if f.Unique != nil && *f.Unique {
sql += " UNIQUE"
}
if f.Default != nil && f.Primary == nil {
sql += fmt.Sprintf(" DEFAULT %s", utils.SQLQuoteValue(*f.Default))
}
if f.References != nil {
sql += fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(f.References.Table), utils.SQLQuoteIdent(f.References.Column))
if f.References.OnDelete != "" {
sql += fmt.Sprintf(" ON DELETE %s", f.References.OnDelete)
}
if f.References.OnUpdate != "" {
sql += fmt.Sprintf(" ON UPDATE %s", f.References.OnUpdate)
}
}
return sql
}
func (f *Field) SQLReferences() string {
if f.References == nil {
return ""
}
return fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(f.References.Table), utils.SQLQuoteIdent(f.References.Column))
}

15
sql/ddl.go Normal file
View File

@@ -0,0 +1,15 @@
package sql
var (
DDL_CREATE_TABLE = "CREATE TABLE IF NOT EXISTS %s (\n %s\n);"
DDL_DROP_TABLE = "DROP TABLE IF EXISTS %s;"
DDL_TRUNCATE_TABLE = "TRUNCATE TABLE %s;"
DDL_ALTER_TABLE = "ALTER TABLE %s %s;"
DDL_ADD_COLUMN = "ALTER TABLE %s ADD COLUMN %s;"
DDL_DROP_COLUMN = "ALTER TABLE %s DROP COLUMN %s;"
DDL_ALTER_COLUMN = "ALTER TABLE %s ALTER COLUMN %s %s;"
DDL_ADD_FOREIGN_KEY = "ALTER TABLE %s ADD FOREIGN KEY (%s) REFERENCES %s (%s);"
DDL_DROP_FOREIGN_KEY = "ALTER TABLE %s DROP FOREIGN KEY %s;"
DDL_ADD_CONSTRAINT = "ALTER TABLE %s ADD CONSTRAINT %s %s;"
DDL_DROP_CONSTRAINT = "ALTER TABLE %s DROP CONSTRAINT %s;"
)

7
sql/dml.go Normal file
View File

@@ -0,0 +1,7 @@
package sql
var (
DML_INSERT_INTO = "INSERT INTO %s (%s) VALUES (%s);"
DML_UPDATE = "UPDATE %s SET %s%s;"
DML_DELETE = "DELETE FROM %s%s;"
)

203
sql/validate.go Normal file
View File

@@ -0,0 +1,203 @@
package sql
import (
"fmt"
"regexp"
"sort"
"strings"
"git.secnex.io/secnex/pgson/schema"
)
var (
IdentifierRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
TypeMap = map[string]string{
"string": "VARCHAR",
"int": "INTEGER",
"float": "FLOAT",
"bool": "BOOLEAN",
"date": "DATE",
"time": "TIME",
"datetime": "TIMESTAMP",
"uuid": "UUID",
"json": "JSONB",
"hash": "VARCHAR",
}
FkActions = map[string]struct{}{
"CASCADE": {},
"RESTRICT": {},
"SET NULL": {},
"NO ACTION": {},
"SET DEFAULT": {},
}
DefaultFunctions = map[string]struct{}{
"CURRENT_TIMESTAMP": {},
"now()": {},
"uuid_generate_v4()": {},
}
AlgorithmMap = map[string]struct{}{
"argon2": {},
"bcrypt": {},
"md5": {},
"sha256": {},
"sha512": {},
}
)
func ValidateIdent(name string) error {
if !IdentifierRegex.MatchString(name) {
return fmt.Errorf("invalid identifier: %s", name)
}
return nil
}
func ValidateType(name string) (string, error) {
pg, ok := TypeMap[name]
if !ok {
return "", fmt.Errorf("invalid type: %s", name)
}
return pg, nil
}
func ValidateAction(act string) (string, error) {
a := strings.ToUpper(act)
if _, ok := FkActions[a]; !ok {
return "", fmt.Errorf("invalid action: %s", act)
}
return a, nil
}
func ValidateDefault(def string) (string, error) {
if regexp.MustCompile(`^\d+(\.\d+)?$`).MatchString(def) || regexp.MustCompile(`^'.*'$`).MatchString(def) {
return def, nil
}
if _, ok := DefaultFunctions[strings.ToLower(def)]; ok {
return def, nil
}
return "", fmt.Errorf("invalid default: %s", def)
}
func ValidateReferences(refs *schema.Reference) error {
if err := ValidateIdent(refs.Table); err != nil {
return err
}
if err := ValidateIdent(refs.Column); err != nil {
return err
}
return nil
}
func ValidateAlgorithm(algo string) error {
a := strings.ToUpper(algo)
if _, ok := AlgorithmMap[a]; !ok {
return fmt.Errorf("invalid algorithm: %s", algo)
}
return nil
}
func ValidatePrimary(primary bool) error {
if primary {
return nil
}
return fmt.Errorf("primary is required")
}
func ValidateUnique(unique bool) error {
if unique {
return nil
}
return fmt.Errorf("unique is required")
}
func ValidateComment(comment *string) error {
if comment == nil {
return nil
}
return ValidateIdent(*comment)
}
func QuoteIdent(name string) string {
return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
func QuoteLiteral(s string) string {
s = strings.ReplaceAll(s, "\x00", "")
var b strings.Builder
b.Grow(len(s) + 2)
b.WriteByte('\'')
for _, r := range s {
switch r {
case '\'':
b.WriteString("''")
case '\\':
b.WriteString("\\\\")
case '\x00':
continue
case '\t', '\n', '\r':
b.WriteRune(r)
default:
if r < 0x20 {
b.WriteString(fmt.Sprintf("\\%03o", r))
} else {
b.WriteRune(r)
}
}
}
b.WriteByte('\'')
return b.String()
}
func QuoteValue(v any) string {
if v == nil {
return "NULL"
}
switch val := v.(type) {
case string:
if matched, _ := regexp.MatchString(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, strings.ToLower(val)); matched {
return QuoteLiteral(val)
}
return QuoteLiteral(val)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%v", val)
case float32, float64:
return fmt.Sprintf("%v", val)
case bool:
if val {
return "TRUE"
}
return "FALSE"
default:
return QuoteLiteral(fmt.Sprintf("%v", val))
}
}
func BuildWhereClause(conditions map[string]any) (string, error) {
if len(conditions) == 0 {
return "", nil
}
var clauses []string
keys := make([]string, 0, len(conditions))
for k := range conditions {
keys = append(keys, k)
}
sort.Strings(keys)
for _, col := range keys {
if err := ValidateIdent(col); err != nil {
return "", fmt.Errorf("invalid column identifier in WHERE clause: %s", col)
}
value := conditions[col]
if value == nil {
clauses = append(clauses, fmt.Sprintf("%s IS NULL", QuoteIdent(col)))
} else {
clauses = append(clauses, fmt.Sprintf("%s = %s", QuoteIdent(col), QuoteValue(value)))
}
}
return " WHERE " + strings.Join(clauses, " AND "), nil
}

View File

@@ -5,8 +5,8 @@ import (
"os"
)
func JsonToMap(j []byte) ([]map[string]any, error) {
var data []map[string]any
func JsonToMap(j []byte) (map[string]any, error) {
var data map[string]any
err := json.Unmarshal(j, &data)
if err != nil {
return nil, err
@@ -14,7 +14,7 @@ func JsonToMap(j []byte) ([]map[string]any, error) {
return data, nil
}
func ReadJsonFileToMap(path string) ([]map[string]any, error) {
func ReadJsonFileToMap(path string) (map[string]any, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err

View File

@@ -1,31 +0,0 @@
package utils
import (
"fmt"
"strings"
)
func SQLQuoteIdent(id string) string {
return `"` + strings.ReplaceAll(id, `"`, `""`) + `"`
}
func SQLQuoteValue(value any) string {
switch v := value.(type) {
case string:
// Escape single quotes by doubling them (PostgreSQL standard)
escaped := strings.ReplaceAll(v, "'", "''")
return fmt.Sprintf("'%s'", escaped)
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%v", value)
case float32, float64:
return fmt.Sprintf("%v", value)
case bool:
return fmt.Sprintf("%v", value)
case nil:
return "NULL"
}
// For unknown types, convert to string and escape
str := fmt.Sprintf("%v", value)
escaped := strings.ReplaceAll(str, "'", "''")
return fmt.Sprintf("'%s'", escaped)
}

12
utils/string.go Normal file
View File

@@ -0,0 +1,12 @@
package utils
func String(s *string) string {
if s == nil {
return ""
}
return *s
}
func Pointer(s string) *string {
return &s
}