Files
pgproxy/app/proxy/conn.go
Björn Benouarets c6b269cf35 init: Initial commit
2025-12-15 14:13:43 +01:00

342 lines
8.5 KiB
Go

package proxy
import (
"fmt"
"io"
"net"
"sync"
"time"
"git.secnex.io/secnex/masterlog"
)
// peekConn wraps a net.Conn so that the leading bytes of the stream can be
// inspected with peek and are still delivered, in order, by Read.
//
// Bytes fetched by peek are kept in peeked; Read hands those out first and
// only then reads from the underlying connection again.
type peekConn struct {
	net.Conn
	peeked       []byte // bytes taken from Conn by peek
	peekedOffset int    // how many bytes of peeked Read has already returned
}

// newPeekConn wraps conn in a peekConn with an empty peek buffer.
func newPeekConn(conn net.Conn) *peekConn {
	return &peekConn{Conn: conn}
}

// Read returns buffered peeked bytes first; once they are exhausted it reads
// directly from the underlying connection.
func (p *peekConn) Read(b []byte) (int, error) {
	if p.peekedOffset < len(p.peeked) {
		n := copy(b, p.peeked[p.peekedOffset:])
		p.peekedOffset += n
		return n, nil
	}
	return p.Conn.Read(b)
}

// peek returns up to the first n bytes of the stream without consuming them:
// they remain available to subsequent Read calls. It reads from the
// underlying connection only as much as is not already buffered.
//
// The returned slice always starts at the first buffered byte, so peek is
// meant to be used before any Read. If the stream ends before n bytes arrive,
// peek returns the shorter prefix and a nil error; callers must check the
// length. On any other read error it returns nil and that error, but every
// byte received before the error stays buffered and is still returned by Read.
func (p *peekConn) peek(n int) ([]byte, error) {
	if len(p.peeked) >= n {
		return p.peeked[:n], nil
	}
	buf := make([]byte, n-len(p.peeked))
	read, err := io.ReadFull(p.Conn, buf)
	// Buffer whatever arrived even when the read failed: those bytes are
	// already off the socket, and dropping them would corrupt the stream
	// that is later forwarded to the backend.
	p.peeked = append(p.peeked, buf[:read]...)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, err
	}
	return p.peeked, nil
}
// handleConnection proxies one client connection to a backend and returns
// once both copy directions have stopped; both connections are closed on
// return.
//
// The backend is looked up in p.mappings by the hostname returned from
// p.extractHostname (e.g. TLS SNI). If no hostname was found, or it has no
// mapping, the only configured mapping is used when exactly one exists;
// otherwise the connection is dropped. Bytes inspected while routing are
// buffered by the peekConn and forwarded to the backend ahead of the rest of
// the client stream.
//
// The routing phase is bounded by a read deadline so a client that never
// sends its first bytes cannot hold the goroutine and socket open forever.
// The deadline is cleared before proxying so idle sessions are not cut off.
func (p *Proxy) handleConnection(clientConn net.Conn) {
	// Maximum time the client may take to send the bytes needed for routing.
	const routingReadTimeout = 10 * time.Second

	defer clientConn.Close()
	remoteAddr := clientConn.RemoteAddr().String()
	masterlog.Info("New connection", map[string]interface{}{"remoteAddr": remoteAddr})

	// Create peek connection to inspect first bytes
	peekConn := newPeekConn(clientConn)

	// extractHostname blocks until enough bytes arrive; bound that wait. If
	// the deadline fires, no hostname is found and the fallback below applies.
	if err := clientConn.SetReadDeadline(time.Now().Add(routingReadTimeout)); err != nil {
		masterlog.Error("Failed to set read deadline", map[string]interface{}{"remoteAddr": remoteAddr, "error": err})
		return
	}
	hostname := p.extractHostname(peekConn)
	if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
		masterlog.Error("Failed to clear read deadline", map[string]interface{}{"remoteAddr": remoteAddr, "error": err})
		return
	}

	var (
		targetMapping struct {
			host string
			port int
		}
		ok bool
	)

	// If we have a hostname (e.g. from TLS SNI), try to use it first
	if hostname != "" {
		targetMapping, ok = p.mappings[hostname]
		if !ok {
			masterlog.Info("No mapping found for hostname", map[string]interface{}{"hostname": hostname})
		}
	}

	// Fallback: if we couldn't determine a hostname or no mapping exists,
	// but there is exactly one mapping configured, use it as a default backend.
	if !ok && len(p.mappings) == 1 {
		for _, m := range p.mappings {
			targetMapping = m
		}
		masterlog.Info("Using default backend", map[string]interface{}{"host": targetMapping.host, "port": targetMapping.port, "remoteAddr": remoteAddr, "hostname": hostname})
		ok = true
	}

	if !ok {
		masterlog.Error("No suitable backend found", map[string]interface{}{"remoteAddr": remoteAddr, "hostname": hostname})
		return
	}

	masterlog.Info("Proxying", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})

	// Connect to backend PostgreSQL server
	backendAddr := net.JoinHostPort(targetMapping.host, fmt.Sprintf("%d", targetMapping.port))
	backendConn, err := net.DialTimeout("tcp", backendAddr, 10*time.Second)
	if err != nil {
		masterlog.Error("Failed to connect to backend", map[string]interface{}{"backendAddr": backendAddr, "error": err})
		return
	}
	defer backendConn.Close()

	// Start bidirectional proxying. When one direction ends, it closes the
	// destination connection, which unblocks the copy running the other way.
	// Copy errors are expected on teardown and are intentionally ignored.
	var wg sync.WaitGroup
	wg.Add(2)

	// Copy from client (including buffered peeked bytes) to backend
	go func() {
		defer wg.Done()
		io.Copy(backendConn, peekConn)
		backendConn.Close()
	}()

	// Copy from backend to client
	go func() {
		defer wg.Done()
		io.Copy(peekConn, backendConn)
		peekConn.Close()
	}()

	wg.Wait()
	masterlog.Info("Connection closed", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})
}
// extractHostname reports the hostname the client is trying to reach, or ""
// when it cannot be determined. It peeks a single byte to pick a parser: a
// leading 0x16 (TLS handshake record) selects SNI extraction, anything else
// is handed to the PostgreSQL startup-packet path. No bytes are consumed.
func (p *Proxy) extractHostname(conn *peekConn) string {
	head, err := conn.peek(1)
	switch {
	case err != nil:
		return ""
	case len(head) > 0 && head[0] == 0x16:
		return p.extractHostnameFromTLS(conn)
	default:
		return p.extractHostnameFromPostgresStartup(conn)
	}
}
// extractHostnameFromTLS returns the SNI host_name from the TLS ClientHello
// at the start of conn, or "" if the first TLS record is not a complete
// handshake record carrying a ClientHello with a host_name entry.
//
// Data is only peeked, never consumed. Only the first record is examined, so
// a ClientHello fragmented across several records yields "".
func (p *Proxy) extractHostnameFromTLS(conn *peekConn) string {
	// TLS record header: content type (1), legacy version (2), length (2).
	const recordHeaderLen = 5
	header, err := conn.peek(recordHeaderLen)
	if err != nil || len(header) < recordHeaderLen || header[0] != 0x16 { // 0x16 = handshake
		return ""
	}
	recordLen := int(header[3])<<8 | int(header[4])

	record, err := conn.peek(recordHeaderLen + recordLen)
	if err != nil || len(record) < recordHeaderLen+recordLen {
		return ""
	}
	return parseClientHelloSNI(record[recordHeaderLen : recordHeaderLen+recordLen])
}

// parseClientHelloSNI extracts the SNI host_name from msg, the payload of a
// TLS handshake record starting with a ClientHello message. It returns "" if
// msg is not a well-formed ClientHello or carries no server_name extension.
// Every length field is bounds-checked against its enclosing structure, and
// anything after the ClientHello in msg is ignored.
func parseClientHelloSNI(msg []byte) string {
	// Handshake header: msg_type (1, 0x01 = client_hello) + uint24 length.
	if len(msg) < 4 || msg[0] != 0x01 {
		return ""
	}
	bodyLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
	if len(msg) < 4+bodyLen {
		return ""
	}
	body := msg[4 : 4+bodyLen]

	// legacy_version (2) + random (32).
	off := 2 + 32
	// legacy_session_id: 1-byte length prefix.
	if len(body) < off+1 {
		return ""
	}
	off += 1 + int(body[off])
	// cipher_suites: 2-byte length prefix.
	if len(body) < off+2 {
		return ""
	}
	off += 2 + (int(body[off])<<8 | int(body[off+1]))
	// legacy_compression_methods: 1-byte length prefix.
	if len(body) < off+1 {
		return ""
	}
	off += 1 + int(body[off])
	// extensions: 2-byte length prefix.
	if len(body) < off+2 {
		return ""
	}
	extsLen := int(body[off])<<8 | int(body[off+1])
	off += 2
	if len(body) < off+extsLen {
		return ""
	}

	// Each extension: type (2) + length (2) + data.
	exts := body[off : off+extsLen]
	for len(exts) >= 4 {
		extType := int(exts[0])<<8 | int(exts[1])
		extLen := int(exts[2])<<8 | int(exts[3])
		if len(exts) < 4+extLen {
			return ""
		}
		if extType == 0x0000 { // server_name
			return parseServerNameList(exts[4 : 4+extLen])
		}
		exts = exts[4+extLen:]
	}
	return ""
}

// parseServerNameList returns the first host_name (name_type 0x00) entry
// from the data of a server_name extension, or "" if there is none. The data
// is a 2-byte list length followed by entries of name_type (1) + 2-byte
// length + name bytes.
func parseServerNameList(data []byte) string {
	if len(data) < 2 {
		return ""
	}
	listLen := int(data[0])<<8 | int(data[1])
	if len(data) < 2+listLen {
		return ""
	}
	list := data[2 : 2+listLen]
	for len(list) >= 3 {
		nameType := list[0]
		nameLen := int(list[1])<<8 | int(list[2])
		if len(list) < 3+nameLen {
			return ""
		}
		if nameType == 0x00 {
			return string(list[3 : 3+nameLen])
		}
		list = list[3+nameLen:]
	}
	return ""
}
// extractHostnameFromPostgresStartup handles a client whose first packet is
// a plaintext PostgreSQL startup-phase message rather than a TLS ClientHello.
//
// Such a packet does not carry the hostname the client connected to, so this
// function always returns "". It still peeks the whole packet, whose size is
// given by its leading big-endian 32-bit length. This way the packet has been
// received before the caller picks and dials a backend. Lengths outside
// [8, 10000] are not treated as a startup packet; only the 4 length bytes are
// peeked in that case.
func (p *Proxy) extractHostnameFromPostgresStartup(conn *peekConn) string {
	lengthBuf, err := conn.peek(4)
	if err != nil || len(lengthBuf) < 4 {
		return ""
	}
	// The length includes itself (4 bytes) and the 4-byte protocol version
	// or request code, hence the minimum of 8.
	length := int(lengthBuf[0])<<24 | int(lengthBuf[1])<<16 | int(lengthBuf[2])<<8 | int(lengthBuf[3])
	if length < 8 || length > 10000 {
		return ""
	}
	// Peek the full packet. The result is deliberately ignored: the packet
	// contains no hostname, so a short or failed peek changes nothing here.
	_, _ = conn.peek(length)
	return ""
}