init: Initial commit
This commit is contained in:
37
app/config/config.go
Normal file
37
app/config/config.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config represents the proxy configuration
|
||||
type Config struct {
|
||||
Listen struct {
|
||||
Address string `yaml:"address"`
|
||||
Port int `yaml:"port"`
|
||||
} `yaml:"listen"`
|
||||
Mappings []struct {
|
||||
External string `yaml:"external"`
|
||||
Internal string `yaml:"internal"`
|
||||
Port int `yaml:"port"` // Optional, defaults to listen port
|
||||
} `yaml:"mappings"`
|
||||
Debug bool `yaml:"debug"`
|
||||
}
|
||||
|
||||
// loadConfig loads configuration from YAML file
|
||||
func LoadConfig(configPath string) (*Config, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
8
app/go.mod
Normal file
8
app/go.mod
Normal file
@@ -0,0 +1,8 @@
|
||||
module git.secnex.io/secnex/pgproxy
|
||||
|
||||
go 1.25.3
|
||||
|
||||
require (
|
||||
git.secnex.io/secnex/masterlog v0.1.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
6
app/go.sum
Normal file
6
app/go.sum
Normal file
@@ -0,0 +1,6 @@
|
||||
git.secnex.io/secnex/masterlog v0.1.0 h1:74j9CATpfeK0lxpWIQC9ag9083akwG8khi5BwLedD8E=
|
||||
git.secnex.io/secnex/masterlog v0.1.0/go.mod h1:OnDlwEzdkKMnqY+G5O9kHdhoJ6fH1llbVdXpgSc5SdM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
50
app/main.go
Normal file
50
app/main.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"git.secnex.io/secnex/masterlog"
|
||||
"git.secnex.io/secnex/pgproxy/config"
|
||||
"git.secnex.io/secnex/pgproxy/proxy"
|
||||
"git.secnex.io/secnex/pgproxy/utils"
|
||||
)
|
||||
|
||||
func main() {
|
||||
pseudonymizer := masterlog.NewPseudonymizerFromString("1234567890")
|
||||
|
||||
masterlog.SetPseudonymizer(pseudonymizer)
|
||||
masterlog.AddSensitiveFields("password")
|
||||
|
||||
// Load configuration
|
||||
configPath := utils.GetEnv("CONFIG_PATH", "config.yaml")
|
||||
if len(os.Args) > 1 {
|
||||
configPath = os.Args[1]
|
||||
}
|
||||
|
||||
config, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
masterlog.Error("Failed to load configuration", map[string]interface{}{"error": err})
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
masterlog.Info("Loaded configuration", map[string]interface{}{"configPath": configPath, "mappings": len(config.Mappings)})
|
||||
|
||||
if config.Debug {
|
||||
masterlog.SetLevel(masterlog.LevelDebug)
|
||||
} else {
|
||||
masterlog.SetLevel(masterlog.LevelInfo)
|
||||
}
|
||||
|
||||
// Create proxy
|
||||
proxy, err := proxy.NewProxy(config)
|
||||
if err != nil {
|
||||
masterlog.Error("Failed to create proxy", map[string]interface{}{"error": err})
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Start the proxy
|
||||
if err := proxy.Start(); err != nil {
|
||||
masterlog.Error("Proxy error", map[string]interface{}{"error": err})
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
341
app/proxy/conn.go
Normal file
341
app/proxy/conn.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.secnex.io/secnex/masterlog"
|
||||
)
|
||||
|
||||
// peekConn wraps a connection to allow peeking at data without consuming it
|
||||
type peekConn struct {
|
||||
net.Conn
|
||||
peeked []byte
|
||||
peekedOffset int
|
||||
}
|
||||
|
||||
func newPeekConn(conn net.Conn) *peekConn {
|
||||
return &peekConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (p *peekConn) Read(b []byte) (int, error) {
|
||||
if p.peekedOffset < len(p.peeked) {
|
||||
n := copy(b, p.peeked[p.peekedOffset:])
|
||||
p.peekedOffset += n
|
||||
return n, nil
|
||||
}
|
||||
return p.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (p *peekConn) peek(n int) ([]byte, error) {
|
||||
if len(p.peeked) >= n {
|
||||
return p.peeked[:n], nil
|
||||
}
|
||||
|
||||
needed := n - len(p.peeked)
|
||||
buf := make([]byte, needed)
|
||||
read, err := io.ReadFull(p.Conn, buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.peeked = append(p.peeked, buf[:read]...)
|
||||
return p.peeked[:len(p.peeked)], nil
|
||||
}
|
||||
|
||||
// handleConnection handles a new client connection
|
||||
func (p *Proxy) handleConnection(clientConn net.Conn) {
|
||||
defer clientConn.Close()
|
||||
|
||||
remoteAddr := clientConn.RemoteAddr().String()
|
||||
masterlog.Info("New connection", map[string]interface{}{"remoteAddr": remoteAddr})
|
||||
|
||||
// Create peek connection to inspect first bytes
|
||||
peekConn := newPeekConn(clientConn)
|
||||
|
||||
// Try to extract hostname
|
||||
hostname := p.extractHostname(peekConn)
|
||||
var (
|
||||
targetMapping struct {
|
||||
host string
|
||||
port int
|
||||
}
|
||||
ok bool
|
||||
)
|
||||
|
||||
// If we have a hostname (e.g. from TLS SNI), try to use it first
|
||||
if hostname != "" {
|
||||
targetMapping, ok = p.mappings[hostname]
|
||||
if !ok {
|
||||
masterlog.Info("No mapping found for hostname", map[string]interface{}{"hostname": hostname})
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: if we couldn't determine a hostname or no mapping exists,
|
||||
// but there is exactly one mapping configured, use it as a default backend.
|
||||
if !ok {
|
||||
if len(p.mappings) == 1 {
|
||||
for _, m := range p.mappings {
|
||||
targetMapping = m
|
||||
break
|
||||
}
|
||||
masterlog.Info("Using default backend", map[string]interface{}{"host": targetMapping.host, "port": targetMapping.port, "remoteAddr": remoteAddr, "hostname": hostname})
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
masterlog.Error("No suitable backend found", map[string]interface{}{"remoteAddr": remoteAddr, "hostname": hostname})
|
||||
return
|
||||
}
|
||||
|
||||
masterlog.Info("Proxying", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})
|
||||
|
||||
// Connect to backend PostgreSQL server
|
||||
backendAddr := net.JoinHostPort(targetMapping.host, fmt.Sprintf("%d", targetMapping.port))
|
||||
backendConn, err := net.DialTimeout("tcp", backendAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
masterlog.Error("Failed to connect to backend", map[string]interface{}{"backendAddr": backendAddr, "error": err})
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// Start bidirectional proxying
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Copy from client to backend
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(backendConn, peekConn)
|
||||
backendConn.Close()
|
||||
}()
|
||||
|
||||
// Copy from backend to client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(peekConn, backendConn)
|
||||
peekConn.Close()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
masterlog.Info("Connection closed", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})
|
||||
}
|
||||
|
||||
// extractHostname extracts hostname from connection (TLS SNI or PostgreSQL startup)
|
||||
func (p *Proxy) extractHostname(conn *peekConn) string {
|
||||
// Peek at first byte to determine protocol
|
||||
firstBytes, err := conn.peek(1)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TLS handshake starts with 0x16
|
||||
if len(firstBytes) > 0 && firstBytes[0] == 0x16 {
|
||||
return p.extractHostnameFromTLS(conn)
|
||||
}
|
||||
|
||||
// Otherwise, try PostgreSQL startup message
|
||||
return p.extractHostnameFromPostgresStartup(conn)
|
||||
}
|
||||
|
||||
// extractHostnameFromTLS extracts hostname by parsing TLS ClientHello manually
|
||||
func (p *Proxy) extractHostnameFromTLS(conn *peekConn) string {
|
||||
// Read enough bytes to parse TLS ClientHello
|
||||
// TLS Record Header: 5 bytes (type, version, length)
|
||||
peekBuf, err := conn.peek(5)
|
||||
if err != nil || len(peekBuf) < 5 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check TLS record type (should be 0x16 for Handshake)
|
||||
if peekBuf[0] != 0x16 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Get record length
|
||||
recordLength := int(peekBuf[3])<<8 | int(peekBuf[4])
|
||||
|
||||
// Read the full TLS record (5 bytes header + recordLength bytes)
|
||||
fullRecord, err := conn.peek(5 + recordLength)
|
||||
if err != nil || len(fullRecord) < 5+recordLength {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse ClientHello to extract SNI
|
||||
// Skip TLS record header (5 bytes)
|
||||
handshake := fullRecord[5:]
|
||||
if len(handshake) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handshake message type (should be 0x01 for ClientHello)
|
||||
if handshake[0] != 0x01 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handshake message length (3 bytes)
|
||||
handshakeLength := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
|
||||
if len(handshake) < 4+handshakeLength {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip handshake header (4 bytes) and protocol version (2 bytes)
|
||||
offset := 4 + 2
|
||||
|
||||
// Read random (32 bytes)
|
||||
if len(handshake) < offset+32 {
|
||||
return ""
|
||||
}
|
||||
offset += 32
|
||||
|
||||
// Read session ID length (1 byte) and session ID
|
||||
if len(handshake) < offset+1 {
|
||||
return ""
|
||||
}
|
||||
sessionIDLength := int(handshake[offset])
|
||||
offset++
|
||||
if len(handshake) < offset+sessionIDLength {
|
||||
return ""
|
||||
}
|
||||
offset += sessionIDLength
|
||||
|
||||
// Read cipher suites length (2 bytes) and cipher suites
|
||||
if len(handshake) < offset+2 {
|
||||
return ""
|
||||
}
|
||||
cipherSuitesLength := int(handshake[offset])<<8 | int(handshake[offset+1])
|
||||
offset += 2
|
||||
if len(handshake) < offset+cipherSuitesLength {
|
||||
return ""
|
||||
}
|
||||
offset += cipherSuitesLength
|
||||
|
||||
// Read compression methods length (1 byte) and compression methods
|
||||
if len(handshake) < offset+1 {
|
||||
return ""
|
||||
}
|
||||
compressionMethodsLength := int(handshake[offset])
|
||||
offset++
|
||||
if len(handshake) < offset+compressionMethodsLength {
|
||||
return ""
|
||||
}
|
||||
offset += compressionMethodsLength
|
||||
|
||||
// Read extensions length (2 bytes)
|
||||
if len(handshake) < offset+2 {
|
||||
return ""
|
||||
}
|
||||
extensionsLength := int(handshake[offset])<<8 | int(handshake[offset+1])
|
||||
offset += 2
|
||||
|
||||
// Parse extensions to find SNI (extension type 0x0000)
|
||||
extensionsEnd := offset + extensionsLength
|
||||
for offset < extensionsEnd-4 {
|
||||
if len(handshake) < offset+4 {
|
||||
break
|
||||
}
|
||||
|
||||
extType := int(handshake[offset])<<8 | int(handshake[offset+1])
|
||||
extLength := int(handshake[offset+2])<<8 | int(handshake[offset+3])
|
||||
offset += 4
|
||||
|
||||
if len(handshake) < offset+extLength {
|
||||
break
|
||||
}
|
||||
|
||||
// SNI extension type is 0x0000
|
||||
if extType == 0x0000 {
|
||||
// SNI format: list length (2 bytes) + server name list
|
||||
if extLength < 2 {
|
||||
break
|
||||
}
|
||||
sniListLength := int(handshake[offset])<<8 | int(handshake[offset+1])
|
||||
sniOffset := offset + 2
|
||||
|
||||
if len(handshake) < sniOffset+sniListLength {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse server name entry
|
||||
if sniListLength >= 3 {
|
||||
// Name type (1 byte) - 0x00 for hostname
|
||||
nameType := handshake[sniOffset]
|
||||
if nameType == 0x00 {
|
||||
nameLength := int(handshake[sniOffset+1])<<8 | int(handshake[sniOffset+2])
|
||||
if len(handshake) >= sniOffset+3+nameLength {
|
||||
hostname := string(handshake[sniOffset+3 : sniOffset+3+nameLength])
|
||||
return hostname
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
offset += extLength
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHostnameFromPostgresStartup extracts hostname from PostgreSQL startup message
|
||||
func (p *Proxy) extractHostnameFromPostgresStartup(conn *peekConn) string {
|
||||
// Peek at first 4 bytes to get message length
|
||||
lengthBuf, err := conn.peek(4)
|
||||
if err != nil || len(lengthBuf) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
length := int(lengthBuf[0])<<24 | int(lengthBuf[1])<<16 | int(lengthBuf[2])<<8 | int(lengthBuf[3])
|
||||
if length < 8 || length > 10000 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Peek at the full startup message
|
||||
startupBuf, err := conn.peek(length)
|
||||
if err != nil || len(startupBuf) < length {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip length (4 bytes) and protocol version (4 bytes)
|
||||
offset := 8
|
||||
|
||||
// Parse key-value pairs
|
||||
for offset < len(startupBuf)-1 {
|
||||
// Find null terminator for key
|
||||
keyEnd := offset
|
||||
for keyEnd < len(startupBuf) && startupBuf[keyEnd] != 0 {
|
||||
keyEnd++
|
||||
}
|
||||
if keyEnd >= len(startupBuf) {
|
||||
break
|
||||
}
|
||||
|
||||
_ = string(startupBuf[offset:keyEnd]) // key (not used, PostgreSQL startup doesn't contain hostname)
|
||||
offset = keyEnd + 1
|
||||
|
||||
// Find null terminator for value
|
||||
valueEnd := offset
|
||||
for valueEnd < len(startupBuf) && startupBuf[valueEnd] != 0 {
|
||||
valueEnd++
|
||||
}
|
||||
if valueEnd >= len(startupBuf) {
|
||||
break
|
||||
}
|
||||
|
||||
_ = string(startupBuf[offset:valueEnd]) // value (not used, PostgreSQL startup doesn't contain hostname)
|
||||
offset = valueEnd + 1
|
||||
|
||||
// PostgreSQL startup message doesn't contain hostname directly
|
||||
// Hostname should come from TLS SNI for TLS connections
|
||||
|
||||
// If we hit the final null byte, we're done
|
||||
if offset >= len(startupBuf) || startupBuf[offset] == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
68
app/proxy/proxy.go
Normal file
68
app/proxy/proxy.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"git.secnex.io/secnex/masterlog"
|
||||
"git.secnex.io/secnex/pgproxy/config"
|
||||
)
|
||||
|
||||
// Proxy handles PostgreSQL connection proxying
|
||||
type Proxy struct {
|
||||
listener net.Listener
|
||||
config *config.Config
|
||||
mappings map[string]struct {
|
||||
host string
|
||||
port int
|
||||
}
|
||||
}
|
||||
|
||||
// NewProxy creates a new proxy instance
|
||||
func NewProxy(config *config.Config) (*Proxy, error) {
|
||||
addr := fmt.Sprintf("%s:%d", config.Listen.Address, config.Listen.Port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build mappings map for quick lookup
|
||||
mappings := make(map[string]struct {
|
||||
host string
|
||||
port int
|
||||
})
|
||||
for _, mapping := range config.Mappings {
|
||||
port := mapping.Port
|
||||
if port == 0 {
|
||||
port = 5432 // Default PostgreSQL port
|
||||
}
|
||||
mappings[mapping.External] = struct {
|
||||
host string
|
||||
port int
|
||||
}{
|
||||
host: mapping.Internal,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
listener: listener,
|
||||
config: config,
|
||||
mappings: mappings,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start starts the proxy server
|
||||
func (p *Proxy) Start() error {
|
||||
masterlog.Info("PostgreSQL proxy listening", map[string]interface{}{"address": p.config.Listen.Address, "port": p.config.Listen.Port})
|
||||
|
||||
for {
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
masterlog.Error("Error accepting connection", map[string]interface{}{"error": err})
|
||||
continue
|
||||
}
|
||||
|
||||
go p.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
22
app/utils/env.go
Normal file
22
app/utils/env.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetEnv(key, defaultValue string) string {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func GetEnvBool(key string, defaultValue bool) bool {
|
||||
value := strings.ToLower(GetEnv(key, ""))
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value == "true" || value == "1"
|
||||
}
|
||||
Reference in New Issue
Block a user