package proxy

import (
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"git.secnex.io/secnex/masterlog"
)

// peekConn wraps a connection so its leading bytes can be inspected before
// proxying. Bytes captured by peek are replayed by Read before any further
// data is read from the underlying connection.
type peekConn struct {
	net.Conn
	peeked       []byte
	peekedOffset int
}

// newPeekConn wraps conn in a peekConn with an empty peek buffer.
func newPeekConn(conn net.Conn) *peekConn {
	return &peekConn{Conn: conn}
}

// Read first drains any bytes buffered by peek, then reads from the
// underlying connection.
func (p *peekConn) Read(b []byte) (int, error) {
	if p.peekedOffset < len(p.peeked) {
		n := copy(b, p.peeked[p.peekedOffset:])
		p.peekedOffset += n
		return n, nil
	}
	return p.Conn.Read(b)
}

// peek returns up to the first n bytes of the stream and keeps them buffered
// so that subsequent Read calls return them again.
//
// If the connection hits EOF before n bytes arrive, the shorter buffered
// prefix is returned with a nil error; callers must check the length. Any
// other read error (e.g. an expired read deadline) is returned as (nil, err).
// In every case, bytes that were actually read are kept in the buffer so they
// are still replayed by Read.
//
// peek does not account for bytes already returned by Read, so it should
// only be used before the first Read.
func (p *peekConn) peek(n int) ([]byte, error) {
	if len(p.peeked) >= n {
		return p.peeked[:n], nil
	}
	buf := make([]byte, n-len(p.peeked))
	read, err := io.ReadFull(p.Conn, buf)
	// Store what arrived before looking at err: these bytes have already been
	// consumed from the socket, and dropping them would corrupt the stream
	// forwarded to the backend.
	p.peeked = append(p.peeked, buf[:read]...)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, err
	}
	return p.peeked, nil
}

// handleConnection serves one client connection. It sniffs the first bytes
// to choose a backend (by TLS SNI when available), falls back to the only
// configured mapping if there is exactly one, dials the backend and copies
// data in both directions until either side finishes.
func (p *Proxy) handleConnection(clientConn net.Conn) {
	defer clientConn.Close()

	remoteAddr := clientConn.RemoteAddr().String()
	masterlog.Info("New connection", map[string]interface{}{"remoteAddr": remoteAddr})

	// Bound the sniffing phase so a client that never sends (enough) bytes
	// cannot stall backend selection forever. Bytes read before the deadline
	// expires stay buffered in pc and are still forwarded.
	const sniffTimeout = 10 * time.Second
	if err := clientConn.SetReadDeadline(time.Now().Add(sniffTimeout)); err != nil {
		masterlog.Error("Failed to set read deadline", map[string]interface{}{"remoteAddr": remoteAddr, "error": err})
		return
	}

	// Wrap the connection so the first bytes can be inspected and replayed.
	pc := newPeekConn(clientConn)
	hostname := p.extractHostname(pc)

	// The proxied session itself must not be subject to the sniff deadline.
	if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
		masterlog.Error("Failed to clear read deadline", map[string]interface{}{"remoteAddr": remoteAddr, "error": err})
		return
	}

	var (
		targetMapping struct {
			host string
			port int
		}
		ok bool
	)

	// If we have a hostname (e.g. from TLS SNI), try to use it first.
	if hostname != "" {
		targetMapping, ok = p.mappings[hostname]
		if !ok {
			masterlog.Info("No mapping found for hostname", map[string]interface{}{"hostname": hostname})
		}
	}

	// Fallback: if no hostname was determined or it has no mapping, but
	// exactly one mapping is configured, use it as the default backend.
	if !ok && len(p.mappings) == 1 {
		for _, m := range p.mappings {
			targetMapping = m
		}
		masterlog.Info("Using default backend", map[string]interface{}{"host": targetMapping.host, "port": targetMapping.port, "remoteAddr": remoteAddr, "hostname": hostname})
		ok = true
	}

	if !ok {
		masterlog.Error("No suitable backend found", map[string]interface{}{"remoteAddr": remoteAddr, "hostname": hostname})
		return
	}

	masterlog.Info("Proxying", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})

	// Connect to the backend PostgreSQL server.
	backendAddr := net.JoinHostPort(targetMapping.host, fmt.Sprintf("%d", targetMapping.port))
	backendConn, err := net.DialTimeout("tcp", backendAddr, 10*time.Second)
	if err != nil {
		masterlog.Error("Failed to connect to backend", map[string]interface{}{"backendAddr": backendAddr, "error": err})
		return
	}
	defer backendConn.Close()

	// Pipe both directions. When one direction ends, its destination is
	// closed so the opposite copy unblocks and the session tears down.
	var wg sync.WaitGroup
	wg.Add(2)

	// Client -> backend (replays any sniffed bytes first via pc.Read).
	go func() {
		defer wg.Done()
		// Copy errors are expected here (e.g. the peer or the other goroutine
		// closed the connection); either way the session is over.
		_, _ = io.Copy(backendConn, pc)
		backendConn.Close()
	}()

	// Backend -> client.
	go func() {
		defer wg.Done()
		_, _ = io.Copy(pc, backendConn)
		pc.Close()
	}()

	wg.Wait()
	masterlog.Info("Connection closed", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})
}

// extractHostname peeks at the first byte of conn. If it is 0x16 (a TLS
// handshake record), it tries to read the server name from the ClientHello's
// SNI extension; otherwise it defers to extractHostnameFromPostgresStartup.
// It returns "" when no hostname can be determined. All inspected bytes stay
// buffered in conn.
//
// NOTE(review): standard PostgreSQL clients begin TLS with an 8-byte
// SSLRequest (first byte 0x00) rather than a raw ClientHello, so SNI is only
// visible here for clients that open TLS directly — confirm this matches the
// clients this proxy is meant to serve.
func (p *Proxy) extractHostname(conn *peekConn) string {
	first, err := conn.peek(1)
	if err != nil || len(first) == 0 {
		return ""
	}
	if first[0] == 0x16 {
		return p.extractHostnameFromTLS(conn)
	}
	return p.extractHostnameFromPostgresStartup(conn)
}

// extractHostnameFromTLS parses the TLS ClientHello buffered in conn and
// returns the host_name entry of its server_name (SNI) extension. It returns
// "" if the first record is not a ClientHello, is truncated or malformed, or
// carries no SNI. Only the first TLS record is examined; a ClientHello that
// is fragmented across several records is not reassembled.
func (p *Proxy) extractHostnameFromTLS(conn *peekConn) string {
	// TLS record header: content type (1), legacy version (2), length (2).
	header, err := conn.peek(5)
	if err != nil || len(header) < 5 || header[0] != 0x16 {
		return ""
	}
	recordLength := int(header[3])<<8 | int(header[4])

	record, err := conn.peek(5 + recordLength)
	if err != nil || len(record) < 5+recordLength {
		return ""
	}
	handshake := record[5 : 5+recordLength]

	// Handshake header: msg_type (1, 0x01 = ClientHello), length (3).
	if len(handshake) < 4 || handshake[0] != 0x01 {
		return ""
	}
	helloLength := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
	if len(handshake) < 4+helloLength {
		return ""
	}
	// Restrict parsing to the ClientHello body so every bounds check below
	// is relative to it.
	hello := handshake[4 : 4+helloLength]

	// Skip legacy_version (2) and random (32).
	offset := 2 + 32

	// session_id: 1-byte length prefix.
	if len(hello) < offset+1 {
		return ""
	}
	offset += 1 + int(hello[offset])

	// cipher_suites: 2-byte length prefix.
	if len(hello) < offset+2 {
		return ""
	}
	offset += 2 + (int(hello[offset])<<8 | int(hello[offset+1]))

	// compression_methods: 1-byte length prefix.
	if len(hello) < offset+1 {
		return ""
	}
	offset += 1 + int(hello[offset])

	// extensions: 2-byte length prefix.
	if len(hello) < offset+2 {
		return ""
	}
	extensionsEnd := offset + 2 + (int(hello[offset])<<8 | int(hello[offset+1]))
	offset += 2

	// Each extension: type (2), length (2), data.
	for offset+4 <= extensionsEnd && offset+4 <= len(hello) {
		extType := int(hello[offset])<<8 | int(hello[offset+1])
		extLength := int(hello[offset+2])<<8 | int(hello[offset+3])
		offset += 4
		if offset+extLength > len(hello) {
			return ""
		}

		if extType == 0x0000 { // server_name
			// server_name_list length (2), then entries of
			// name_type (1, 0x00 = host_name) + name length (2) + name.
			ext := hello[offset : offset+extLength]
			if len(ext) < 2 {
				return ""
			}
			listLength := int(ext[0])<<8 | int(ext[1])
			list := ext[2:]
			if listLength > len(list) {
				return ""
			}
			list = list[:listLength]
			if len(list) < 3 || list[0] != 0x00 {
				return ""
			}
			nameLength := int(list[1])<<8 | int(list[2])
			if len(list) < 3+nameLength {
				return ""
			}
			return string(list[3 : 3+nameLength])
		}

		offset += extLength
	}
	return ""
}

// extractHostnameFromPostgresStartup is the non-TLS branch of
// extractHostname and always returns "". A plaintext PostgreSQL startup
// message carries key/value parameters but no server hostname, so there is
// nothing to route on. It deliberately reads nothing from the connection:
// waiting for a length-prefixed message that a non-PostgreSQL client might
// never finish sending would only delay or stall backend selection.
func (p *Proxy) extractHostnameFromPostgresStartup(_ *peekConn) string {
	return ""
}