452 lines
12 KiB
Go
452 lines
12 KiB
Go
package proxy
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.secnex.io/secnex/masterlog"
|
|
)
|
|
|
|
// peekConn wraps a connection to allow peeking at data without consuming it
|
|
type peekConn struct {
|
|
net.Conn
|
|
peeked []byte
|
|
peekedOffset int
|
|
}
|
|
|
|
func newPeekConn(conn net.Conn) *peekConn {
|
|
return &peekConn{Conn: conn}
|
|
}
|
|
|
|
func (p *peekConn) Read(b []byte) (int, error) {
|
|
if p.peekedOffset < len(p.peeked) {
|
|
n := copy(b, p.peeked[p.peekedOffset:])
|
|
p.peekedOffset += n
|
|
return n, nil
|
|
}
|
|
return p.Conn.Read(b)
|
|
}
|
|
|
|
func (p *peekConn) peek(n int) ([]byte, error) {
|
|
if len(p.peeked) >= n {
|
|
return p.peeked[:n], nil
|
|
}
|
|
|
|
needed := n - len(p.peeked)
|
|
buf := make([]byte, needed)
|
|
read, err := io.ReadFull(p.Conn, buf)
|
|
if err != nil && err != io.EOF {
|
|
return nil, err
|
|
}
|
|
|
|
p.peeked = append(p.peeked, buf[:read]...)
|
|
return p.peeked[:len(p.peeked)], nil
|
|
}
|
|
|
|
// handleConnection serves one accepted client connection until either side
// closes it. clientConn is always closed on return.
//
// Flow:
//  1. Peek (without consuming) the first byte to classify the connection.
//  2. If p.config.TLS.Enabled and the client either starts a TLS handshake
//     directly (first byte 0x16) or sends a PostgreSQL SSLRequest (answered
//     with 'S'), terminate TLS here and take the hostname from SNI.
//     Otherwise keep the raw connection and ask extractHostname for a name.
//  3. Look the hostname up in p.mappings; if that fails and exactly one
//     mapping is configured, use it as the default backend.
//  4. Dial the backend and copy bytes in both directions.
//
// Everything before the backend dial runs under a connection deadline, so a
// client that never finishes the handshake cannot hold the goroutine open.
func (p *Proxy) handleConnection(clientConn net.Conn) {
	defer clientConn.Close()

	const (
		// setupTimeout bounds the pre-proxy phase (first bytes, TLS
		// handshake, SSLRequest exchange). It is cleared before proxying so
		// idle database sessions are not cut off.
		setupTimeout = 30 * time.Second
		// pgSSLRequest is the 8-byte PostgreSQL SSLRequest packet:
		// Int32 length 8 followed by Int32 code 80877103.
		pgSSLRequest = "\x00\x00\x00\x08\x04\xd2\x16\x2f"
	)

	remoteAddr := clientConn.RemoteAddr().String()
	masterlog.Info("New connection", map[string]interface{}{"remoteAddr": remoteAddr})

	// Best effort: if the deadline cannot be set the connection still works,
	// it is just not protected against stalled handshakes.
	_ = clientConn.SetDeadline(time.Now().Add(setupTimeout))

	// All reads go through probe so that peeked bytes are replayed to whoever
	// reads next (the TLS server or the backend copy).
	probe := newPeekConn(clientConn)

	firstByte, err := probe.peek(1)
	if err != nil {
		masterlog.Error("Failed to peek connection", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
		return
	}
	if len(firstByte) == 0 {
		// Peer closed without sending anything; no point dialing a backend.
		masterlog.Info("Client closed connection before sending data", map[string]interface{}{"remoteAddr": remoteAddr})
		return
	}

	terminateTLS := false
	if p.config.TLS.Enabled {
		if firstByte[0] == 0x16 {
			// Direct TLS: the ClientHello is the first thing on the wire.
			terminateTLS = true
		} else if head, err := probe.peek(len(pgSSLRequest)); err == nil && string(head) == pgSSLRequest {
			// Standard PostgreSQL negotiation: SSLRequest, server replies
			// 'S', then the client starts TLS on this same connection.
			// Replying 'S' without terminating TLS would forward the
			// client's ClientHello to the backend as if it were a startup
			// packet.
			if err := p.handlePostgresSSLRequest(probe); err != nil {
				masterlog.Error("Failed to handle PostgreSQL SSL request", map[string]interface{}{"error": err})
				return
			}
			terminateTLS = true
		}
	}

	var (
		conn     *peekConn
		hostname string
	)
	if terminateTLS {
		tlsConfig, err := p.resolveTLSConfig()
		if err != nil {
			masterlog.Error("Failed to get TLS certificate", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
			return
		}

		// The handshake must read through probe, not clientConn: at least the
		// first byte of the ClientHello has already been taken off the socket
		// by peek and only probe can replay it.
		tlsConn := tls.Server(probe, tlsConfig)
		if err := tlsConn.Handshake(); err != nil {
			masterlog.Error("TLS handshake failed", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
			return
		}

		if sni := tlsConn.ConnectionState().ServerName; sni != "" {
			hostname = sni
			masterlog.Info("Extracted hostname from TLS SNI", map[string]interface{}{"hostname": hostname})
		}
		conn = newPeekConn(tlsConn)
	} else {
		// Not terminated here: for a TLS passthrough this still yields the
		// SNI name; for plain PostgreSQL it yields "".
		conn = probe
		hostname = p.extractHostname(conn)
	}

	var (
		targetMapping struct {
			host string
			port int
		}
		ok bool
	)

	// Prefer an explicit mapping for the hostname (e.g. from TLS SNI).
	if hostname != "" {
		if targetMapping, ok = p.mappings[hostname]; !ok {
			masterlog.Info("No mapping found for hostname", map[string]interface{}{"hostname": hostname})
		}
	}

	// Fallback: with exactly one mapping configured, use it as the default
	// backend when no hostname was found or it had no mapping.
	if !ok && len(p.mappings) == 1 {
		for _, m := range p.mappings {
			targetMapping = m
			break
		}
		masterlog.Info("Using default backend", map[string]interface{}{"host": targetMapping.host, "port": targetMapping.port, "remoteAddr": remoteAddr, "hostname": hostname})
		ok = true
	}

	if !ok {
		masterlog.Error("No suitable backend found", map[string]interface{}{"remoteAddr": remoteAddr, "hostname": hostname})
		return
	}

	masterlog.Info("Proxying", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})

	// If the next packet is (still) an SSLRequest, e.g. one sent inside an
	// already-terminated TLS session, answer it with 'S'.
	if p.config.TLS.Enabled {
		if err := p.handlePostgresSSLRequest(conn); err != nil {
			masterlog.Error("Failed to handle PostgreSQL SSL request", map[string]interface{}{"error": err})
			return
		}
	}

	// Setup is done; from here on the session may legitimately sit idle.
	_ = clientConn.SetDeadline(time.Time{})

	backendAddr := net.JoinHostPort(targetMapping.host, fmt.Sprintf("%d", targetMapping.port))
	backendConn, err := net.DialTimeout("tcp", backendAddr, 10*time.Second)
	if err != nil {
		masterlog.Error("Failed to connect to backend", map[string]interface{}{"backendAddr": backendAddr, "error": err})
		return
	}
	defer backendConn.Close()

	pipeConns(conn, backendConn)
	masterlog.Info("Connection closed", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})
}

// resolveTLSConfig returns the server-side TLS configuration for a new
// connection. Preference order: certmagic (Let's Encrypt) when configured,
// then the certificate stored on the proxy, and finally a certificate
// obtained on demand from getCertificateForMappings(p.config).
func (p *Proxy) resolveTLSConfig() (*tls.Config, error) {
	switch {
	case p.certmagic != nil:
		return createTLSConfigWithLetsEncrypt(p.certmagic), nil
	case p.certificate != nil:
		return &tls.Config{
			Certificates: []tls.Certificate{*p.certificate},
			ClientAuth:   tls.NoClientCert,
		}, nil
	default:
		cert, err := getCertificateForMappings(p.config)
		if err != nil {
			return nil, err
		}
		return &tls.Config{
			Certificates: []tls.Certificate{cert},
			ClientAuth:   tls.NoClientCert,
		}, nil
	}
}

// pipeConns copies data in both directions between client and backend and
// returns once both directions have finished. When one direction ends, the
// destination of that direction is closed, which unblocks the other copy.
func pipeConns(client, backend net.Conn) {
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		// Copy errors are expected at teardown (a peer closed); ignore them.
		_, _ = io.Copy(backend, client)
		backend.Close()
	}()

	go func() {
		defer wg.Done()
		_, _ = io.Copy(client, backend)
		client.Close()
	}()

	wg.Wait()
}
|
|
|
|
// extractHostname extracts hostname from connection (TLS SNI or PostgreSQL startup)
|
|
func (p *Proxy) extractHostname(conn *peekConn) string {
|
|
// Peek at first byte to determine protocol
|
|
firstBytes, err := conn.peek(1)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
// TLS handshake starts with 0x16
|
|
if len(firstBytes) > 0 && firstBytes[0] == 0x16 {
|
|
return p.extractHostnameFromTLS(conn)
|
|
}
|
|
|
|
// Otherwise, try PostgreSQL startup message
|
|
return p.extractHostnameFromPostgresStartup(conn)
|
|
}
|
|
|
|
// extractHostnameFromTLS extracts hostname by parsing TLS ClientHello manually
|
|
func (p *Proxy) extractHostnameFromTLS(conn *peekConn) string {
|
|
// Read enough bytes to parse TLS ClientHello
|
|
// TLS Record Header: 5 bytes (type, version, length)
|
|
peekBuf, err := conn.peek(5)
|
|
if err != nil || len(peekBuf) < 5 {
|
|
return ""
|
|
}
|
|
|
|
// Check TLS record type (should be 0x16 for Handshake)
|
|
if peekBuf[0] != 0x16 {
|
|
return ""
|
|
}
|
|
|
|
// Get record length
|
|
recordLength := int(peekBuf[3])<<8 | int(peekBuf[4])
|
|
|
|
// Read the full TLS record (5 bytes header + recordLength bytes)
|
|
fullRecord, err := conn.peek(5 + recordLength)
|
|
if err != nil || len(fullRecord) < 5+recordLength {
|
|
return ""
|
|
}
|
|
|
|
// Parse ClientHello to extract SNI
|
|
// Skip TLS record header (5 bytes)
|
|
handshake := fullRecord[5:]
|
|
if len(handshake) < 4 {
|
|
return ""
|
|
}
|
|
|
|
// Handshake message type (should be 0x01 for ClientHello)
|
|
if handshake[0] != 0x01 {
|
|
return ""
|
|
}
|
|
|
|
// Handshake message length (3 bytes)
|
|
handshakeLength := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
|
|
if len(handshake) < 4+handshakeLength {
|
|
return ""
|
|
}
|
|
|
|
// Skip handshake header (4 bytes) and protocol version (2 bytes)
|
|
offset := 4 + 2
|
|
|
|
// Read random (32 bytes)
|
|
if len(handshake) < offset+32 {
|
|
return ""
|
|
}
|
|
offset += 32
|
|
|
|
// Read session ID length (1 byte) and session ID
|
|
if len(handshake) < offset+1 {
|
|
return ""
|
|
}
|
|
sessionIDLength := int(handshake[offset])
|
|
offset++
|
|
if len(handshake) < offset+sessionIDLength {
|
|
return ""
|
|
}
|
|
offset += sessionIDLength
|
|
|
|
// Read cipher suites length (2 bytes) and cipher suites
|
|
if len(handshake) < offset+2 {
|
|
return ""
|
|
}
|
|
cipherSuitesLength := int(handshake[offset])<<8 | int(handshake[offset+1])
|
|
offset += 2
|
|
if len(handshake) < offset+cipherSuitesLength {
|
|
return ""
|
|
}
|
|
offset += cipherSuitesLength
|
|
|
|
// Read compression methods length (1 byte) and compression methods
|
|
if len(handshake) < offset+1 {
|
|
return ""
|
|
}
|
|
compressionMethodsLength := int(handshake[offset])
|
|
offset++
|
|
if len(handshake) < offset+compressionMethodsLength {
|
|
return ""
|
|
}
|
|
offset += compressionMethodsLength
|
|
|
|
// Read extensions length (2 bytes)
|
|
if len(handshake) < offset+2 {
|
|
return ""
|
|
}
|
|
extensionsLength := int(handshake[offset])<<8 | int(handshake[offset+1])
|
|
offset += 2
|
|
|
|
// Parse extensions to find SNI (extension type 0x0000)
|
|
extensionsEnd := offset + extensionsLength
|
|
for offset < extensionsEnd-4 {
|
|
if len(handshake) < offset+4 {
|
|
break
|
|
}
|
|
|
|
extType := int(handshake[offset])<<8 | int(handshake[offset+1])
|
|
extLength := int(handshake[offset+2])<<8 | int(handshake[offset+3])
|
|
offset += 4
|
|
|
|
if len(handshake) < offset+extLength {
|
|
break
|
|
}
|
|
|
|
// SNI extension type is 0x0000
|
|
if extType == 0x0000 {
|
|
// SNI format: list length (2 bytes) + server name list
|
|
if extLength < 2 {
|
|
break
|
|
}
|
|
sniListLength := int(handshake[offset])<<8 | int(handshake[offset+1])
|
|
sniOffset := offset + 2
|
|
|
|
if len(handshake) < sniOffset+sniListLength {
|
|
break
|
|
}
|
|
|
|
// Parse server name entry
|
|
if sniListLength >= 3 {
|
|
// Name type (1 byte) - 0x00 for hostname
|
|
nameType := handshake[sniOffset]
|
|
if nameType == 0x00 {
|
|
nameLength := int(handshake[sniOffset+1])<<8 | int(handshake[sniOffset+2])
|
|
if len(handshake) >= sniOffset+3+nameLength {
|
|
hostname := string(handshake[sniOffset+3 : sniOffset+3+nameLength])
|
|
return hostname
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
offset += extLength
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// extractHostnameFromPostgresStartup extracts hostname from PostgreSQL startup message
|
|
func (p *Proxy) extractHostnameFromPostgresStartup(conn *peekConn) string {
|
|
// Peek at first 4 bytes to get message length
|
|
lengthBuf, err := conn.peek(4)
|
|
if err != nil || len(lengthBuf) < 4 {
|
|
return ""
|
|
}
|
|
|
|
length := int(lengthBuf[0])<<24 | int(lengthBuf[1])<<16 | int(lengthBuf[2])<<8 | int(lengthBuf[3])
|
|
if length < 8 || length > 10000 {
|
|
return ""
|
|
}
|
|
|
|
// Peek at the full startup message
|
|
startupBuf, err := conn.peek(length)
|
|
if err != nil || len(startupBuf) < length {
|
|
return ""
|
|
}
|
|
|
|
// Skip length (4 bytes) and protocol version (4 bytes)
|
|
offset := 8
|
|
|
|
// Parse key-value pairs
|
|
for offset < len(startupBuf)-1 {
|
|
// Find null terminator for key
|
|
keyEnd := offset
|
|
for keyEnd < len(startupBuf) && startupBuf[keyEnd] != 0 {
|
|
keyEnd++
|
|
}
|
|
if keyEnd >= len(startupBuf) {
|
|
break
|
|
}
|
|
|
|
_ = string(startupBuf[offset:keyEnd]) // key (not used, PostgreSQL startup doesn't contain hostname)
|
|
offset = keyEnd + 1
|
|
|
|
// Find null terminator for value
|
|
valueEnd := offset
|
|
for valueEnd < len(startupBuf) && startupBuf[valueEnd] != 0 {
|
|
valueEnd++
|
|
}
|
|
if valueEnd >= len(startupBuf) {
|
|
break
|
|
}
|
|
|
|
_ = string(startupBuf[offset:valueEnd]) // value (not used, PostgreSQL startup doesn't contain hostname)
|
|
offset = valueEnd + 1
|
|
|
|
// PostgreSQL startup message doesn't contain hostname directly
|
|
// Hostname should come from TLS SNI for TLS connections
|
|
|
|
// If we hit the final null byte, we're done
|
|
if offset >= len(startupBuf) || startupBuf[offset] == 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// handlePostgresSSLRequest handles PostgreSQL SSL request after TLS handshake
|
|
// PostgreSQL clients may send an SSL request (0x00 0x00 0x00 0x08 0x04 0xD2 0x16 0x2F)
|
|
// We respond with 'S' (0x53) to indicate SSL is supported, since encryption is already handled by TLS
|
|
func (p *Proxy) handlePostgresSSLRequest(conn *peekConn) error {
|
|
// Peek at first 8 bytes to check for SSL request
|
|
peekBuf, err := conn.peek(8)
|
|
if err != nil {
|
|
// If we can't peek, it's not an SSL request, continue normally
|
|
return nil
|
|
}
|
|
|
|
// PostgreSQL SSL request: 4 bytes length (0x00 0x00 0x00 0x08) + 4 bytes code (0x04 0xD2 0x16 0x2F)
|
|
sslRequest := []byte{0x00, 0x00, 0x00, 0x08, 0x04, 0xD2, 0x16, 0x2F}
|
|
|
|
if len(peekBuf) >= 8 {
|
|
isSSLRequest := true
|
|
for i := 0; i < 8; i++ {
|
|
if peekBuf[i] != sslRequest[i] {
|
|
isSSLRequest = false
|
|
break
|
|
}
|
|
}
|
|
|
|
if isSSLRequest {
|
|
// Consume the SSL request (read it from the connection)
|
|
buf := make([]byte, 8)
|
|
_, err := io.ReadFull(conn, buf)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read SSL request: %w", err)
|
|
}
|
|
|
|
// Respond with 'S' (0x53) - SSL supported
|
|
// Since encryption is already handled by TLS termination, we signal SSL support
|
|
_, err = conn.Write([]byte{'S'})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write SSL response: %w", err)
|
|
}
|
|
|
|
masterlog.Info("Handled PostgreSQL SSL request", map[string]interface{}{"response": "S", "reason": "TLS already enabled"})
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|