feat(sql): SQL Injection
This commit is contained in:
@@ -1,66 +0,0 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
)
|
||||
|
||||
var (
|
||||
requiredAttributes = [2]string{"name", "type"}
|
||||
onlySupportedTypes = []string{"string", "int", "float", "bool", "date", "time", "datetime", "uuid", "json", "hash"}
|
||||
)
|
||||
|
||||
func ValidateField(field *Field) error {
|
||||
if err := ValidateHashAlgorithm(field); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAttributes(field); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateFieldType(field); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAttributes(field *Field) error {
|
||||
// Check if the field has keys that are necessary for the field
|
||||
for _, attribute := range requiredAttributes {
|
||||
if _, ok := map[string]any{field.Name: field.Type}[attribute]; !ok {
|
||||
return fmt.Errorf("attribute %s is required", attribute)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateFieldType(field *Field) error {
|
||||
if !slices.Contains(onlySupportedTypes, field.Type) {
|
||||
return fmt.Errorf("unsupported type: %s", field.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateHashAlgorithm(field *Field) error {
|
||||
if field.Type != "hash" {
|
||||
return nil
|
||||
}
|
||||
if field.Algorithm == nil {
|
||||
return fmt.Errorf("algorithm is required for hash field")
|
||||
}
|
||||
var hashAlgorithms = []string{"argon2", "bcrypt", "md5", "sha256", "sha512"}
|
||||
for _, algorithm := range hashAlgorithms {
|
||||
if *field.Algorithm == algorithm {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unsupported hash algorithm: %s", *field.Algorithm)
|
||||
}
|
||||
|
||||
func (t *Table) Validate() error {
|
||||
for _, field := range t.Schema {
|
||||
if err := ValidateField(&field); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2,9 +2,6 @@ package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"git.secnex.io/secnex/pgson/utils"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
@@ -32,19 +29,6 @@ type Reference struct {
|
||||
OnUpdate string `json:"on_update"`
|
||||
}
|
||||
|
||||
var fieldTypeToSQLType = map[string]string{
|
||||
"string": "VARCHAR",
|
||||
"int": "INTEGER",
|
||||
"float": "FLOAT",
|
||||
"bool": "BOOLEAN",
|
||||
"date": "DATE",
|
||||
"time": "TIME",
|
||||
"datetime": "TIMESTAMP",
|
||||
"uuid": "UUID",
|
||||
"json": "JSONB",
|
||||
"hash": "VARCHAR",
|
||||
}
|
||||
|
||||
func NewTable(data []byte) (*Table, error) {
|
||||
var table Table
|
||||
err := json.Unmarshal(data, &table)
|
||||
@@ -57,82 +41,3 @@ func NewTable(data []byte) (*Table, error) {
|
||||
func (t *Table) JSON() ([]byte, error) {
|
||||
return json.Marshal(t)
|
||||
}
|
||||
|
||||
func (f *Field) SQL() string {
|
||||
if !utils.IsValidIdentifier(f.Name) {
|
||||
return ""
|
||||
}
|
||||
|
||||
sqlType := fieldTypeToSQLType[f.Type]
|
||||
if sqlType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
quotedName := utils.SQLQuoteIdent(f.Name)
|
||||
sql := fmt.Sprintf("%s %s", quotedName, sqlType)
|
||||
|
||||
if f.Nullable != nil && !*f.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
|
||||
if f.Primary != nil && *f.Primary {
|
||||
sql += " PRIMARY KEY"
|
||||
if f.Default == nil && f.Type == "uuid" {
|
||||
sql += " DEFAULT uuid_generate_v4()"
|
||||
}
|
||||
}
|
||||
|
||||
if f.Unique != nil && *f.Unique {
|
||||
sql += " UNIQUE"
|
||||
}
|
||||
|
||||
if f.Default != nil && f.Primary == nil {
|
||||
def := *f.Default
|
||||
if utils.IsValidDefault(def) {
|
||||
switch {
|
||||
case utils.IsValidDefault(def):
|
||||
if utils.IsValidDefault(def) && (def == "CURRENT_TIMESTAMP" || def == "now()" || def == "uuid_generate_v4()") {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", def)
|
||||
} else {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", utils.SQLQuoteValue(def))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if f.References != nil {
|
||||
ref := f.References
|
||||
if !utils.IsValidIdentifier(ref.Table) || !utils.IsValidIdentifier(ref.Column) {
|
||||
return ""
|
||||
}
|
||||
sql += fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(ref.Table), utils.SQLQuoteIdent(ref.Column))
|
||||
if ref.OnDelete != "" {
|
||||
action, err := utils.SanitizeOnAction(ref.OnDelete)
|
||||
if err == nil && action != "" {
|
||||
sql += fmt.Sprintf(" ON DELETE %s", action)
|
||||
}
|
||||
}
|
||||
if ref.OnUpdate != "" {
|
||||
action, err := utils.SanitizeOnAction(ref.OnUpdate)
|
||||
if err == nil && action != "" {
|
||||
sql += fmt.Sprintf(" ON UPDATE %s", action)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func (f *Field) SQLReferences() string {
|
||||
if f.References == nil {
|
||||
return ""
|
||||
}
|
||||
ref := f.References
|
||||
|
||||
if !utils.IsValidIdentifier(ref.Table) || !utils.IsValidIdentifier(ref.Column) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" REFERENCES %s(%s)", utils.SQLQuoteIdent(ref.Table), utils.SQLQuoteIdent(ref.Column))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user