From: Dan Fuhry Date: Tue, 8 Apr 2025 21:59:03 +0000 (-0400) Subject: [http] add PROXY protocol v1 and v2 support X-Git-Url: https://go.fuhry.dev/?a=commitdiff_plain;h=eabe61c13b25d273eddf42521581a32bf67e5de2;p=runtime.git [http] add PROXY protocol v1 and v2 support --- diff --git a/http/proxy/main.go b/http/proxy/main.go index c9c727b..25e610f 100644 --- a/http/proxy/main.go +++ b/http/proxy/main.go @@ -41,24 +41,22 @@ func main() { flag.Parse() - httpServer, err := server.Create() - if err != nil { - log.Panic(err) - } - go httpServer.ListenAndServeTLS("", "") - - log.Default().Infof("listening on HTTPS at %s", server.Listener.Addr) - - unsecureServer := server.CreateInsecure() - go unsecureServer.ListenAndServe() + go (func() { + if err := server.ListenAndServeTLS(); err != nil { + log.Panic(err) + } + })() - log.Default().Infof("listening on HTTP at %s (redirects to HTTPS only)", server.Listener.InsecureAddr) + go (func() { + if err := server.ListenAndServe(); err != nil { + log.Panic(err) + } + })() daemon.SdNotify(false, daemon.SdNotifyReady) <-ctx.Done() shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) defer shutdownCancel() - httpServer.Shutdown(shutdownCtx) - unsecureServer.Shutdown(shutdownCtx) + server.Shutdown(shutdownCtx) } diff --git a/http/proxy_protocol.go b/http/proxy_protocol.go new file mode 100644 index 0000000..6026421 --- /dev/null +++ b/http/proxy_protocol.go @@ -0,0 +1,484 @@ +package http + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "net" + "strconv" + "strings" + + "go.fuhry.dev/runtime/utils/log" +) + +type ProxyProtocolConn interface { + net.Conn + + ProxyAddr() net.Addr +} + +type proxyProtocolListener struct { + net.Listener +} + +type proxyProtocolConn struct { + net.Conn + + reader *bufio.Reader + addrOrigin net.Addr + addrProxy net.Addr + addrLocal net.Addr +} + +type ppVersionCommand uint8 +type ppFamily uint8 + +type ppHeader struct { + vc ppVersionCommand + fa ppFamily + len uint16 +} + +type ppAddressTuple struct { + transport uint8 + + srcAddr net.IP + dstAddr net.IP + srcPort uint16 + dstPort uint16 +} + +type ppUnixTuple struct { + transport uint8 + + src string + dst string +} + +type ppAddress interface { + Src() net.Addr + Dst() net.Addr + Encode() []byte +} + +func (a *ppAddressTuple) Src() net.Addr { + switch a.transport { + case trDatagram: + return &net.UDPAddr{ + IP: a.srcAddr, + Port: int(a.srcPort), + } + case trStream: + return &net.TCPAddr{ + IP: a.srcAddr, + Port: int(a.srcPort), + } + default: + return nil + } +} + +func (a *ppAddressTuple) Dst() net.Addr { + return &net.TCPAddr{ + IP: a.dstAddr, + Port: int(a.dstPort), + } +} + +func (a *ppAddressTuple) Encode() []byte { + var buf []byte + buf = append(buf, a.srcAddr...) + buf = append(buf, a.dstAddr...) + buf = append(buf, uint8(a.srcPort>>8)) + buf = append(buf, uint8(a.srcPort)) + buf = append(buf, uint8(a.dstPort>>8)) + buf = append(buf, uint8(a.dstPort)) + + return buf +} + +func (a *ppUnixTuple) toAddr(s string) net.Addr { + switch a.transport { + case trDatagram: + return &net.UnixAddr{ + Name: s, + Net: "unixgram", + } + case trStream: + return &net.UnixAddr{ + Name: s, + Net: "unix", + } + default: + return nil + } +} + +func (a *ppUnixTuple) Src() net.Addr { + return a.toAddr(a.src) +} + +func (a *ppUnixTuple) Dst() net.Addr { + return a.toAddr(a.dst) +} + +func (a *ppUnixTuple) Encode() []byte { + buf := make([]byte, v2UnixSocketLength*2) + copy(buf, []byte(a.src)) + copy(buf[v2UnixSocketLength:], []byte(a.dst)) + + return buf +} + +func newPPFamily(af, tr uint8) ppFamily { + return ppFamily(af<<4 | tr) +} + +func (f ppFamily) String() string { + var ret string + switch f.Transport() { + case trStream: + ret = "TCP" + case trDatagram: + ret = "UDP" + default: + return "UNKNOWN" + } + + switch f.AddressFamily() { + case afInet: + ret += "4" + case afInet6: + ret += "6" + default: + return "UNKNOWN" + } + return ret +} + +const ( + v2HeaderLength = 16 + v2UnixSocketLength = 108 + + version2 uint8 = 2 + + cmdLocal uint8 = 0 + cmdProxy uint8 = 1 + + afUnspec uint8 = 0 + afInet uint8 = 1 + afInet6 uint8 = 2 + afUnix uint8 = 3 + + trUnspec uint8 = 0 + trStream uint8 = 1 + trDatagram uint8 = 2 +) + +var proxyProtocolHeaderV1 = [6]byte{ + 'P', 'R', 'O', 'X', 'Y', ' ', +} +var proxyProtocolMagicV2 = [12]byte{ + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 'Q', 'U', 'I', 'T', 0x0a, +} +var ErrBadMagic = errors.New("bad magic") +var ErrV1BadSyntax = errors.New("v1 header is incorrectly formatted") +var ErrLength = errors.New("header is wrong length") +var ErrUnrecognizedVersion = errors.New("unrecognized protocol version") +var ErrUnrecognizedCommand = errors.New("unrecognized command") +var ErrUnspecifiedAddressFamily = errors.New("unspecified address family") +var ErrUnrecognizedAddressFamily = errors.New("unrecognized address family") +var ErrUnrecognizedTransportProtocol = errors.New("unrecognized transport protocol") + +func (v ppVersionCommand) Version() uint8 { + return uint8(v) >> 4 +} + +func (v ppVersionCommand) Command() uint8 { + return uint8(v) & 0x0f +} + +func (v ppFamily) AddressFamily() uint8 { + return uint8(v) >> 4 +} + +func (v ppFamily) Transport() uint8 { + return uint8(v) & 0x0f +} + +func (v ppHeader) ReadAddress(r *bufio.Reader) (ppAddress, error) { + addressBuf := make([]byte, int(v.len)) + n, err := r.Read(addressBuf) + if err != nil { + return nil, err + } + if n != int(v.len) { + return nil, fmt.Errorf("got wrong number of bytes reading address, expected %d got %d, buf: %+v", v.len, n, addressBuf) // ErrLength + } + + var expectLength int + switch v.fa.AddressFamily() { + case afUnspec: + return nil, ErrUnspecifiedAddressFamily + case afInet: + expectLength = 12 // two IPv4 addresses + 2 ports + case afInet6: + expectLength = 36 // two IPv6 addresses + 2 ports + case afUnix: + expectLength = 2 * v2UnixSocketLength // two UNIX socket addresses + default: + panic("not reached") + } + + if n != expectLength { + return nil, fmt.Errorf("proxy protocol message is wrong length for af %d, proto %d: expected %d, got %d, buf: %+v", + v.fa.AddressFamily(), v.fa.Transport(), expectLength, n, addressBuf) + } + + switch v.fa.AddressFamily() { + case afInet: + return &ppAddressTuple{ + transport: v.fa.Transport(), + srcAddr: addressBuf[0:4], + dstAddr: addressBuf[4:8], + srcPort: uint16(addressBuf[8])<<8 | uint16(addressBuf[9]), + dstPort: uint16(addressBuf[10])<<8 | uint16(addressBuf[11]), + }, nil + case afInet6: + return &ppAddressTuple{ + transport: v.fa.Transport(), + srcAddr: addressBuf[0:16], + dstAddr: addressBuf[16:32], + srcPort: uint16(addressBuf[32])<<8 | uint16(addressBuf[33]), + dstPort: uint16(addressBuf[34])<<8 | uint16(addressBuf[35]), + }, nil + case afUnix: + return &ppUnixTuple{ + src: strings.Trim(string(addressBuf[0:v2UnixSocketLength]), "\x00"), + dst: strings.Trim(string(addressBuf[v2UnixSocketLength:]), "\x00"), + }, nil + default: + panic("not reached") + } +} + +func NewProxyProtocolV2Header(cmd, af, tr uint8) ppHeader { + return ppHeader{ + vc: ppVersionCommand(version2<<4 | cmd), + fa: newPPFamily(af, tr), + } +} + +func (v ppHeader) Encode(addr ppAddress) []byte { + out := make([]byte, v2HeaderLength) + copy(out, proxyProtocolMagicV2[:]) + + out[12] = uint8(v.vc) + out[13] = uint8(v.fa) + addrBytes := addr.Encode() + out[14] = uint8(len(addrBytes) >> 8) + out[15] = uint8(len(addrBytes)) + out = append(out, addrBytes...) + + return out +} + +func parseProxyProtocolV2Header(header []byte) (*ppHeader, error) { + if len(header) != v2HeaderLength { + return nil, ErrLength + } + + if !bytes.Equal(header[:len(proxyProtocolMagicV2)], proxyProtocolMagicV2[:]) { + return nil, ErrBadMagic + } + + vc := ppVersionCommand(header[12]) + fa := ppFamily(header[13]) + + if vc.Version() != version2 { + return nil, ErrUnrecognizedVersion + } + + if vc.Command() != cmdLocal && vc.Command() != cmdProxy { + return nil, ErrUnrecognizedCommand + } + + af := fa.AddressFamily() + if af != afUnspec && af != afInet && af != afInet6 && af != afUnix { + return nil, ErrUnrecognizedAddressFamily + } + + pr := fa.Transport() + if pr != trUnspec && pr != trStream && pr != trDatagram { + return nil, ErrUnrecognizedTransportProtocol + } + + return &ppHeader{ + vc: vc, + fa: fa, + // length is a uint16 in network (big endian) order + len: uint16(header[14])<<8 | uint16(header[15]), + }, nil +} + +// PROXY protocol v2 header: +// offset len type desc +// 0 12 []byte magic +// 12 1 uint8 version +// 13 1 uint8 family +// 14 2 uint16 length +// [union] tcp/udp over IPv4 (overall 12 bytes) +// 16 4 uint32 src addr +// 20 4 uint32 dst addr +// 24 2 uint16 src port +// 26 2 uint16 dst port +// [union] tcp/udp over IPv6 (overall 36 bytes) +// 16 16 []uint8 src addr +// 32 16 []uint8 dst addr +// 48 2 uint16 src port +// 50 2 uint16 dst port +// [union] unix socket (overall 216 bytes) +// 16 108 []uint8 src addr +// 124 108 []uint8 dst addr + +func ListenProxyProtocol(network, addr string) (net.Listener, error) { + l, err := net.Listen(network, addr) + if err != nil { + return nil, err + } + return &proxyProtocolListener{l}, nil +} + +func (p *proxyProtocolListener) Accept() (net.Conn, error) { + conn, err := p.Listener.Accept() + if err != nil { + return nil, err + } + + return newProxyProtocolConn(conn) +} + +func newProxyProtocolConn(conn net.Conn) (net.Conn, error) { + var acceptErr error + reader := bufio.NewReader(conn) + header, err := reader.Peek(len(proxyProtocolMagicV2)) + if err != nil { + acceptErr = err + goto nonProxyAddr + } + + if bytes.Equal(header[:len(proxyProtocolMagicV2)], proxyProtocolMagicV2[:]) { + hdr := make([]byte, v2HeaderLength) + n, err := reader.Read(hdr) + if err != nil || n != v2HeaderLength { + acceptErr = err + goto nonProxyAddr + } + parsedHeader, err := parseProxyProtocolV2Header(hdr) + if err != nil { + acceptErr = err + goto nonProxyAddr + } + addr, err := parsedHeader.ReadAddress(reader) + if err != nil { + acceptErr = err + goto nonProxyAddr + } + return &proxyProtocolConn{ + Conn: conn, + reader: reader, + addrLocal: addr.Dst(), + addrProxy: conn.RemoteAddr(), + addrOrigin: addr.Src(), + }, nil + } else if bytes.Equal(header[:len(proxyProtocolHeaderV1)], proxyProtocolHeaderV1[:]) { + buf, err := reader.ReadBytes('\n') + if err != nil { + acceptErr = ErrV1BadSyntax + goto nonProxyAddr + } + hdr := strings.Split(strings.Trim(string(buf), "\r\n"), " ") + if len(hdr) != 6 { + acceptErr = ErrLength + goto nonProxyAddr + } + + var ppAddr ppAddress + switch hdr[1] { + case newPPFamily(afInet, trStream).String(), newPPFamily(afInet6, trStream).String(): + srcPort, err := strconv.Atoi(hdr[4]) + if err != nil { + acceptErr = ErrV1BadSyntax + goto nonProxyAddr + } + dstPort, err := strconv.Atoi(hdr[5]) + if err != nil { + acceptErr = ErrV1BadSyntax + goto nonProxyAddr + } + ppAddr = &ppAddressTuple{ + transport: trStream, + srcAddr: net.ParseIP(hdr[2]), + dstAddr: net.ParseIP(hdr[3]), + srcPort: uint16(srcPort), + dstPort: uint16(dstPort), + } + default: + acceptErr = ErrUnrecognizedTransportProtocol + goto nonProxyAddr + } + + return &proxyProtocolConn{ + Conn: conn, + reader: reader, + addrLocal: ppAddr.Dst(), + addrProxy: conn.RemoteAddr(), + addrOrigin: ppAddr.Src(), + }, nil + } else { + acceptErr = nil + goto nonProxyAddr + } + +nonProxyAddr: + if acceptErr != nil { + log.Default().WithPrefix("http.ProxyProtocol").Warningf( + "error accepting conn %s->%s with proxy protocol: %v", + conn.RemoteAddr(), conn.LocalAddr(), acceptErr, + ) + + conn.Write([]byte( + fmt.Sprintf( + "unable to accept connection, proxy protocol error encountered: %v", + acceptErr, + ), + )) + + conn.Close() + return nil, acceptErr + } + return &proxyProtocolConn{ + Conn: conn, + reader: reader, + addrLocal: conn.LocalAddr(), + addrProxy: nil, + addrOrigin: conn.RemoteAddr(), + }, nil +} + +func (c *proxyProtocolConn) Read(buf []byte) (int, error) { + return c.reader.Read(buf) +} + +func (c *proxyProtocolConn) LocalAddr() net.Addr { + return c.addrLocal +} + +func (c *proxyProtocolConn) ProxyAddr() net.Addr { + return c.addrProxy +} + +func (c *proxyProtocolConn) RemoteAddr() net.Addr { + return c.addrOrigin +} diff --git a/http/proxy_protocol_test.go b/http/proxy_protocol_test.go new file mode 100644 index 0000000..c755b1b --- /dev/null +++ b/http/proxy_protocol_test.go @@ -0,0 +1,169 @@ +package http + +import ( + "bufio" + "errors" + "fmt" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + ipv6TestAddr1 = net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x44} + ipv6TestAddr2 = net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88} +) + +func mergeByteSeqs(seqs ...[]byte) []byte { + var out []byte + for _, seq := range seqs { + out = append(out, seq...) + } + return out +} + +func padSocket(in string) []byte { + out := make([]byte, v2UnixSocketLength) + copy(out, []byte(in)) + return out +} + +func TestProxyProtocolEncode(t *testing.T) { + type testCase struct { + cmd, af, tr uint8 + addr ppAddress + expectBytes []byte + } + + testCases := []*testCase{ + { + cmdProxy, afInet, trStream, + &ppAddressTuple{trStream, net.IP{10, 0, 0, 1}, net.IP{10, 0, 0, 2}, 1024, 1025}, + []byte{ + 0x21, 0x11, + 0, 12, + 0x0a, 0x00, 0x00, 0x01, + 0x0a, 0x00, 0x00, 0x02, + 0x04, 0x00, + 0x04, 0x01, + }, + }, + { + cmdProxy, afInet6, trStream, + &ppAddressTuple{trStream, ipv6TestAddr1, ipv6TestAddr2, 1024, 1025}, + mergeByteSeqs( + []byte{0x21, 0x21, 0, 36}, + ipv6TestAddr1, + ipv6TestAddr2, + []byte{0x04, 0x00}, + []byte{0x04, 0x01}), + }, + { + cmdProxy, afUnix, trStream, + &ppUnixTuple{trStream, "/run/test/one.sock", "/run/test/two.sock"}, + mergeByteSeqs( + []byte{0x21, 0x31, 0, 216}, + padSocket("/run/test/one.sock"), + padSocket("/run/test/two.sock")), + }, + } + + for _, tc := range testCases { + header := NewProxyProtocolV2Header(tc.cmd, tc.af, tc.tr) + out := header.Encode(tc.addr) + assert.Equal(t, out[len(proxyProtocolMagicV2):], tc.expectBytes) + } +} + +func TestProxyProtocolDecode(t *testing.T) { + type testCase struct { + ppHeader []byte + expectAddr string + } + + listener, err := ListenProxyProtocol("tcp", ":0") + assert.NoError(t, err) + defer listener.Close() + + testCases := []*testCase{ + { + NewProxyProtocolV2Header(cmdProxy, afInet, trStream).Encode( + &ppAddressTuple{trStream, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 2}, 1024, 1025}), + "127.0.0.1:1024->[%s]->127.0.0.2:1025", + }, + { + []byte("PROXY TCP4 127.0.0.1 127.0.0.2 1024 1025\r\n"), + "127.0.0.1:1024->[%s]->127.0.0.2:1025", + }, + { + []byte("PROXY TCP6 fe80::1 fe80::2 1024 1025\r\n"), + "[fe80::1]:1024->[%s]->[fe80::2]:1025", + }, + } + + handle := func(conn net.Conn) { + defer conn.Close() + reader := bufio.NewReader(conn) + pConn := conn.(ProxyProtocolConn) + + for { + lineBytes, _, err := reader.ReadLine() + if errors.Is(err, net.ErrClosed) { + return + } + assert.NoError(t, err) + + line := string(lineBytes) + switch strings.Trim(line, "\r\n") { + case "whoami": + conn.Write([]byte( + fmt.Sprintf("%s->[%+v]->%s\n", conn.RemoteAddr(), pConn.ProxyAddr(), conn.LocalAddr()), + )) + case "echo": + conn.Write([]byte(line + "\n")) + case "bye": + conn.Write([]byte("tallyho\n")) + return + } + } + } + + go (func() { + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + + assert.NoError(t, err) + return + } + + go handle(conn) + } + })() + + for _, tc := range testCases { + conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String()) + assert.NoError(t, err) + reader := bufio.NewReader(conn) + fmt.Printf("%+v\n", tc.ppHeader) + _, err = conn.Write(tc.ppHeader) + assert.NoError(t, err) + + expectAddr := fmt.Sprintf(tc.expectAddr, conn.LocalAddr().String()) + + conn.Write([]byte("whoami\n")) + answer, _, err := reader.ReadLine() + assert.NoError(t, err) + assert.Equal(t, expectAddr, strings.Trim(string(answer), "\r\n")) + conn.Write([]byte("bye\n")) + answer, _, err = reader.ReadLine() + assert.NoError(t, err) + assert.Equal(t, "tallyho", strings.Trim(string(answer), "\r\n")) + conn.Close() + } +} diff --git a/http/route_action_proxy.go b/http/route_action_proxy.go index 83dccae..a39a225 100644 --- a/http/route_action_proxy.go +++ b/http/route_action_proxy.go @@ -85,6 +85,7 @@ func (su *staticUpstreamAction) Handle(w http.ResponseWriter, r *http.Request, n http.Error(w, fmt.Sprintf("error setting up connection to backend: %v", err), http.StatusInternalServerError) + return } response, err := client.Do(upstreamReq) if err != nil { diff --git a/http/server.go b/http/server.go index 9f88be1..41afc7f 100644 --- a/http/server.go +++ b/http/server.go @@ -2,11 +2,13 @@ package http import ( "context" + "crypto/tls" "errors" "fmt" "net" "net/http" "os" + "sync" "go.fuhry.dev/runtime/mtls" "go.fuhry.dev/runtime/utils/log" @@ -30,15 +32,19 @@ type VirtualHost struct { } type Listener struct { - Addr string `yaml:"listen"` - InsecureAddr string `yaml:"listen_insecure"` - Certificate string `yaml:"cert"` - VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"` + Addr string `yaml:"listen"` + ProxyProtocol bool `yaml:"proxy_protocol"` + InsecureAddr string `yaml:"listen_insecure"` + Certificate string `yaml:"cert"` + VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"` } type Server struct { Listener *Listener `yaml:"listener"` Context context.Context `yaml:"-"` + + tlsServer *http.Server + httpServer *http.Server } type initHook func(context.Context, *yaml.Node) (context.Context, error) @@ -138,14 +144,82 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error { return nil } -func (s *Server) Create() (*http.Server, error) { +func (s *Server) ListenAndServeTLS() error { listenerCtx := context.WithValue(s.Context, kListener, s.Listener) - return s.Listener.NewHTTPServerWithContext(listenerCtx) + server, err := s.Listener.NewHTTPServerWithContext(listenerCtx) + if err != nil { + return err + } + + var listener net.Listener + if s.Listener.ProxyProtocol { + listener, err = ListenProxyProtocol("tcp", server.Addr) + LoggerFromContext(listenerCtx).Noticef( + "Listening for standard or PROXY protocol TLS connections on %s", listener.Addr(), + ) + } else { + listener, err = net.Listen("tcp", server.Addr) + LoggerFromContext(listenerCtx).Noticef( + "Listening for TLS connnections on %s", listener.Addr(), + ) + } + if err != nil { + return err + } + if server.TLSConfig != nil { + listener = tls.NewListener(listener, server.TLSConfig) + } + + s.tlsServer = server + return server.Serve(listener) } -func (s *Server) CreateInsecure() *http.Server { +func (s *Server) ListenAndServe() error { + var listener net.Listener + var err error + listenerCtx := context.WithValue(s.Context, kListener, s.Listener) - return s.Listener.NewHTTPSRedirectorWithContext(listenerCtx) + server := s.Listener.NewHTTPSRedirectorWithContext(listenerCtx) + if s.Listener.ProxyProtocol { + listener, err = ListenProxyProtocol("tcp", server.Addr) + LoggerFromContext(listenerCtx).Noticef( + "Listening for standard or PROXY protocol connections on %s", listener.Addr(), + ) + } else { + listener, err = net.Listen("tcp", server.Addr) + LoggerFromContext(listenerCtx).Noticef( + "Listening on %s", listener.Addr(), + ) + } + if err != nil { + return err + } + + s.httpServer = server + return server.Serve(listener) +} + +func (s *Server) Shutdown(shutdownCtx context.Context) error { + var wg sync.WaitGroup + var err error + if s.httpServer != nil { + wg.Add(1) + + go (func() { + defer wg.Done() + err = s.httpServer.Shutdown(shutdownCtx) + })() + } + if s.tlsServer != nil { + wg.Add(1) + go (func() { + defer wg.Done() + err = s.tlsServer.Shutdown(shutdownCtx) + })() + } + + wg.Wait() + return err } // NewHTTPServerWithContext creates an http.Server using the proxy's virtual host @@ -160,7 +234,7 @@ func (l *Listener) NewHTTPServerWithContext(ctx context.Context) (*http.Server, lm := log.NewLoggingMiddlewareWithLogger( http.HandlerFunc(l.handle), - logger.AppendPrefix("access")) + logger.AppendPrefix(".access")) server := &http.Server{ Addr: l.Addr,