package proxy

import (
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"git.secnex.io/secnex/masterlog"
)

// peekConn wraps a connection so that leading bytes can be inspected with
// peek without consuming them: every byte that peek pulls off the wrapped
// connection is handed out again by Read before any new data is read.
type peekConn struct {
	net.Conn
	peeked       []byte // bytes already read from Conn for peeking
	peekedOffset int    // how many bytes of peeked Read has already returned
}

// newPeekConn wraps conn in a peekConn with an empty peek buffer.
func newPeekConn(conn net.Conn) *peekConn {
	return &peekConn{Conn: conn}
}

// Read returns buffered peeked bytes first and falls through to the wrapped
// connection once they are exhausted.
func (p *peekConn) Read(b []byte) (int, error) {
	if p.peekedOffset < len(p.peeked) {
		n := copy(b, p.peeked[p.peekedOffset:])
		p.peekedOffset += n
		if p.peekedOffset == len(p.peeked) {
			// Fully drained: drop the buffer (it can be as large as a TLS
			// record) and reset the bookkeeping for any later peek.
			p.peeked = nil
			p.peekedOffset = 0
		}
		return n, nil
	}
	return p.Conn.Read(b)
}

// peek returns the next n unconsumed bytes without consuming them, reading
// from the wrapped connection as needed. If the connection reaches EOF first,
// the shorter slice that is available is returned with a nil error, so
// callers must check the length. Bytes read before any error are retained and
// will still be delivered by Read. The returned slice aliases the internal
// buffer and must not be modified.
func (p *peekConn) peek(n int) ([]byte, error) {
	buffered := len(p.peeked) - p.peekedOffset
	if buffered >= n {
		return p.peeked[p.peekedOffset : p.peekedOffset+n], nil
	}

	buf := make([]byte, n-buffered)
	read, err := io.ReadFull(p.Conn, buf)
	// Keep partial data even on failure; discarding it would corrupt the
	// stream that is proxied afterwards.
	p.peeked = append(p.peeked, buf[:read]...)
	if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
		return nil, err
	}
	return p.peeked[p.peekedOffset:], nil
}

// handleConnection serves one client connection. It optionally terminates
// TLS (answering a PostgreSQL SSLRequest first when present), selects a
// backend from p.mappings by SNI hostname, and falls back to the only mapping
// when exactly one is configured. It then copies bytes in both directions
// until either side closes.
func (p *Proxy) handleConnection(clientConn net.Conn) {
	// setupTimeout bounds the negotiation phase (SSLRequest, TLS handshake,
	// first client message). It is cleared before proxying, because
	// established sessions may legitimately sit idle.
	const setupTimeout = 30 * time.Second

	defer clientConn.Close()

	remoteAddr := clientConn.RemoteAddr().String()
	masterlog.Info("New connection", map[string]interface{}{"remoteAddr": remoteAddr})

	// Without a deadline an untrusted client could park this goroutine forever
	// while we wait for negotiation bytes.
	if err := clientConn.SetDeadline(time.Now().Add(setupTimeout)); err != nil {
		masterlog.Error("Failed to set connection deadline", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
		return
	}

	var hostname string

	// rawConn records every byte read while inspecting the connection, so the
	// same bytes are later delivered to the TLS stack or to the backend.
	rawConn := newPeekConn(clientConn)

	// Standard libpq negotiation sends a plaintext SSLRequest and waits for a
	// one-byte answer before it starts the TLS handshake. Answer it before
	// looking for the handshake record.
	sslRequested := false
	if p.config.TLS.Enabled {
		answered, err := answerPostgresSSLRequest(rawConn)
		if err != nil {
			masterlog.Error("Failed to handle PostgreSQL SSL request", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
			return
		}
		sslRequested = answered
	}

	// A TLS handshake record starts with content type 0x16.
	firstByte, err := rawConn.peek(1)
	if err != nil {
		masterlog.Error("Failed to peek connection", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
		return
	}
	if len(firstByte) == 0 {
		masterlog.Info("Connection closed before any data was received", map[string]interface{}{"remoteAddr": remoteAddr})
		return
	}

	// stream carries the application data that is proxied to the backend:
	// the decrypted TLS session when TLS is terminated here, otherwise the
	// raw client connection.
	var stream *peekConn
	terminatedTLS := p.config.TLS.Enabled && firstByte[0] == 0x16
	if terminatedTLS {
		var tlsConfig *tls.Config
		switch {
		case p.certmagic != nil:
			// Use Let's Encrypt certmagic.
			tlsConfig = createTLSConfigWithLetsEncrypt(p.certmagic)
		case p.certificate != nil:
			// Use the stored certificate for regular TLS.
			tlsConfig = &tls.Config{
				Certificates: []tls.Certificate{*p.certificate},
				ClientAuth:   tls.NoClientCert,
			}
		default:
			// Fallback: obtain a certificate on demand.
			cert, err := getCertificateForMappings(p.config)
			if err != nil {
				masterlog.Error("Failed to get TLS certificate", map[string]interface{}{"error": err})
				return
			}
			tlsConfig = &tls.Config{
				Certificates: []tls.Certificate{cert},
				ClientAuth:   tls.NoClientCert,
			}
		}

		// Handshake over rawConn rather than clientConn, so the bytes that
		// were already peeked (including the 0x16 record type) reach the
		// TLS stack.
		tlsConn := tls.Server(rawConn, tlsConfig)
		if err := tlsConn.Handshake(); err != nil {
			masterlog.Error("TLS handshake failed", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
			return
		}

		if state := tlsConn.ConnectionState(); state.ServerName != "" {
			hostname = state.ServerName
			masterlog.Info("Extracted hostname from TLS SNI", map[string]interface{}{"hostname": hostname})
		}
		stream = newPeekConn(tlsConn)
	} else {
		if sslRequested {
			// We replied 'S', so the client must start TLS next; anything
			// else cannot be forwarded meaningfully.
			masterlog.Error("Expected TLS handshake after PostgreSQL SSL request", map[string]interface{}{"remoteAddr": remoteAddr})
			return
		}
		stream = rawConn
		// With TLS termination disabled, this can still read SNI from a
		// passed-through ClientHello. PostgreSQL startup packets yield "".
		hostname = p.extractHostname(stream)
	}

	var (
		targetMapping struct {
			host string
			port int
		}
		ok bool
	)

	// If we have a hostname (e.g. from TLS SNI), try to use it first.
	if hostname != "" {
		targetMapping, ok = p.mappings[hostname]
		if !ok {
			masterlog.Info("No mapping found for hostname", map[string]interface{}{"hostname": hostname})
		}
	}

	// Fallback: if no hostname or mapping matched but exactly one mapping is
	// configured, use it as the default backend.
	if !ok && len(p.mappings) == 1 {
		for _, m := range p.mappings {
			targetMapping = m
			break
		}
		masterlog.Info("Using default backend", map[string]interface{}{"host": targetMapping.host, "port": targetMapping.port, "remoteAddr": remoteAddr, "hostname": hostname})
		ok = true
	}

	if !ok {
		masterlog.Error("No suitable backend found", map[string]interface{}{"remoteAddr": remoteAddr, "hostname": hostname})
		return
	}

	masterlog.Info("Proxying", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})

	// Inside a terminated TLS session a client may still send an SSLRequest.
	// Reply 'S', since encryption is already provided by the TLS layer.
	if terminatedTLS {
		if err := p.handlePostgresSSLRequest(stream); err != nil {
			masterlog.Error("Failed to handle PostgreSQL SSL request", map[string]interface{}{"error": err})
			return
		}
	}

	// Negotiation is complete; lift the setup deadline for the session.
	if err := clientConn.SetDeadline(time.Time{}); err != nil {
		masterlog.Error("Failed to clear connection deadline", map[string]interface{}{"error": err, "remoteAddr": remoteAddr})
		return
	}

	// Connect to the backend PostgreSQL server.
	backendAddr := net.JoinHostPort(targetMapping.host, fmt.Sprintf("%d", targetMapping.port))
	backendConn, err := net.DialTimeout("tcp", backendAddr, 10*time.Second)
	if err != nil {
		masterlog.Error("Failed to connect to backend", map[string]interface{}{"backendAddr": backendAddr, "error": err})
		return
	}
	defer backendConn.Close()

	// Bidirectional copy. A copy error only means one side went away;
	// closing the opposite connection unblocks the other goroutine.
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		_, _ = io.Copy(backendConn, stream)
		backendConn.Close()
	}()

	go func() {
		defer wg.Done()
		_, _ = io.Copy(stream, backendConn)
		stream.Close()
	}()

	wg.Wait()
	masterlog.Info("Connection closed", map[string]interface{}{"hostname": hostname, "host": targetMapping.host, "port": targetMapping.port})
}

// extractHostname peeks at the start of conn. It returns the TLS SNI
// hostname if the data begins with a TLS handshake record; otherwise it
// defers to the PostgreSQL startup inspection. Nothing is consumed.
func (p *Proxy) extractHostname(conn *peekConn) string {
	firstBytes, err := conn.peek(1)
	if err != nil {
		return ""
	}

	// A TLS handshake record starts with 0x16.
	if len(firstBytes) > 0 && firstBytes[0] == 0x16 {
		return p.extractHostnameFromTLS(conn)
	}

	return p.extractHostnameFromPostgresStartup(conn)
}

// extractHostnameFromTLS parses the ClientHello at the head of conn (without
// consuming it) and returns the first host_name entry of the server_name
// extension, or "" if it is absent or the message is malformed. Only a
// ClientHello contained in a single TLS record is supported.
func (p *Proxy) extractHostnameFromTLS(conn *peekConn) string {
	const (
		recordHeaderLen    = 5 // content type (1), legacy version (2), length (2)
		handshakeHeaderLen = 4 // message type (1), length (3)
		contentHandshake   = 0x16
		typeClientHello    = 0x01
		extServerName      = 0x0000
		nameTypeHostName   = 0x00
	)

	header, err := conn.peek(recordHeaderLen)
	if err != nil || len(header) < recordHeaderLen || header[0] != contentHandshake {
		return ""
	}
	recordLength := int(header[3])<<8 | int(header[4])

	record, err := conn.peek(recordHeaderLen + recordLength)
	if err != nil || len(record) < recordHeaderLen+recordLength {
		return ""
	}

	handshake := record[recordHeaderLen:]
	if len(handshake) < handshakeHeaderLen || handshake[0] != typeClientHello {
		return ""
	}
	helloLength := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3])
	if len(handshake) < handshakeHeaderLen+helloLength {
		return ""
	}
	// Parse only the ClientHello body, so trailing record bytes are never
	// misread as fields.
	hello := handshake[handshakeHeaderLen : handshakeHeaderLen+helloLength]

	// Skip legacy_version (2 bytes) and random (32 bytes).
	off := 2 + 32

	// skipVector advances off past a vector prefixed by a lenBytes-wide
	// big-endian length, and reports false if the vector overruns hello.
	skipVector := func(lenBytes int) bool {
		if off+lenBytes > len(hello) {
			return false
		}
		size := 0
		for _, b := range hello[off : off+lenBytes] {
			size = size<<8 | int(b)
		}
		off += lenBytes + size
		return off <= len(hello)
	}
	// legacy_session_id, cipher_suites, legacy_compression_methods.
	if !skipVector(1) || !skipVector(2) || !skipVector(1) {
		return ""
	}

	if off+2 > len(hello) {
		return ""
	}
	extensionsEnd := off + 2 + (int(hello[off])<<8 | int(hello[off+1]))
	off += 2
	if extensionsEnd > len(hello) {
		extensionsEnd = len(hello)
	}

	for off+4 <= extensionsEnd {
		extType := int(hello[off])<<8 | int(hello[off+1])
		extLength := int(hello[off+2])<<8 | int(hello[off+3])
		off += 4
		if off+extLength > extensionsEnd {
			return ""
		}
		if extType != extServerName {
			off += extLength
			continue
		}

		// server_name extension: ServerNameList length (2 bytes), then
		// entries of name_type (1) + name length (2) + name. Only the first
		// entry is inspected.
		ext := hello[off : off+extLength]
		if len(ext) < 2 {
			return ""
		}
		listLength := int(ext[0])<<8 | int(ext[1])
		if 2+listLength > len(ext) {
			return ""
		}
		list := ext[2 : 2+listLength]
		if len(list) < 3 || list[0] != nameTypeHostName {
			return ""
		}
		nameLength := int(list[1])<<8 | int(list[2])
		if 3+nameLength > len(list) {
			return ""
		}
		return string(list[3 : 3+nameLength])
	}

	return ""
}

// extractHostnameFromPostgresStartup inspects a plaintext PostgreSQL startup
// packet at the head of conn without consuming it. The startup packet
// carries parameters such as user and database but no server hostname, so
// this always returns "". Hostname routing for PostgreSQL relies on TLS SNI.
// When the length prefix is plausible (8..10000 bytes), the full packet is
// peeked, which blocks until it has arrived.
func (p *Proxy) extractHostnameFromPostgresStartup(conn *peekConn) string {
	lengthBuf, err := conn.peek(4)
	if err != nil || len(lengthBuf) < 4 {
		return ""
	}

	length := int(lengthBuf[0])<<24 | int(lengthBuf[1])<<16 | int(lengthBuf[2])<<8 | int(lengthBuf[3])
	if length < 8 || length > 10000 {
		return ""
	}

	// Buffer the complete packet; Read replays it to the backend unchanged.
	// Its key/value parameters are deliberately not interpreted, and a
	// short or failed peek changes nothing because the result is "" anyway.
	_, _ = conn.peek(length)
	return ""
}

// handlePostgresSSLRequest answers a PostgreSQL SSLRequest
// (0x00 0x00 0x00 0x08 0x04 0xD2 0x16 0x2F) if it is the next message on
// conn. It consumes the request and replies 'S' (0x53). Any other data is
// left unconsumed for normal proxying, and an unreadable connection is not
// treated as an error.
func (p *Proxy) handlePostgresSSLRequest(conn *peekConn) error {
	_, err := answerPostgresSSLRequest(conn)
	return err
}

// answerPostgresSSLRequest is the implementation behind
// handlePostgresSSLRequest. It also reports whether an SSLRequest was
// answered, so the caller knows the client will start a TLS handshake next.
func answerPostgresSSLRequest(conn *peekConn) (bool, error) {
	// SSLRequest: Int32 length 8 followed by Int32 code 80877103 (0x04D2162F).
	const sslRequest = "\x00\x00\x00\x08\x04\xd2\x16\x2f"

	head, err := conn.peek(len(sslRequest))
	if err != nil || string(head) != sslRequest {
		// Unreadable or not an SSLRequest: leave the bytes in place.
		return false, nil
	}

	// Consume the request so it is not forwarded to the backend.
	if _, err := io.ReadFull(conn, make([]byte, len(sslRequest))); err != nil {
		return false, fmt.Errorf("failed to read SSL request: %w", err)
	}

	// 'S' tells the client that SSL is supported; TLS is terminated by this proxy.
	if _, err := conn.Write([]byte{'S'}); err != nil {
		return false, fmt.Errorf("failed to write SSL response: %w", err)
	}

	masterlog.Info("Handled PostgreSQL SSL request", map[string]interface{}{"response": "S", "reason": "TLS already enabled"})
	return true, nil
}