--- /dev/null
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+type ProxyProtocolConn interface {
+ net.Conn
+
+ ProxyAddr() net.Addr
+}
+
+type proxyProtocolListener struct {
+ net.Listener
+}
+
+type proxyProtocolConn struct {
+ net.Conn
+
+ reader *bufio.Reader
+ addrOrigin net.Addr
+ addrProxy net.Addr
+ addrLocal net.Addr
+}
+
+type ppVersionCommand uint8
+type ppFamily uint8
+
+type ppHeader struct {
+ vc ppVersionCommand
+ fa ppFamily
+ len uint16
+}
+
+type ppAddressTuple struct {
+ transport uint8
+
+ srcAddr net.IP
+ dstAddr net.IP
+ srcPort uint16
+ dstPort uint16
+}
+
+type ppUnixTuple struct {
+ transport uint8
+
+ src string
+ dst string
+}
+
+type ppAddress interface {
+ Src() net.Addr
+ Dst() net.Addr
+ Encode() []byte
+}
+
+func (a *ppAddressTuple) Src() net.Addr {
+ switch a.transport {
+ case trDatagram:
+ return &net.UDPAddr{
+ IP: a.srcAddr,
+ Port: int(a.srcPort),
+ }
+ case trStream:
+ return &net.TCPAddr{
+ IP: a.srcAddr,
+ Port: int(a.srcPort),
+ }
+ default:
+ return nil
+ }
+}
+
+func (a *ppAddressTuple) Dst() net.Addr {
+ return &net.TCPAddr{
+ IP: a.dstAddr,
+ Port: int(a.dstPort),
+ }
+}
+
+func (a *ppAddressTuple) Encode() []byte {
+ var buf []byte
+ buf = append(buf, a.srcAddr...)
+ buf = append(buf, a.dstAddr...)
+ buf = append(buf, uint8(a.srcPort>>8))
+ buf = append(buf, uint8(a.srcPort))
+ buf = append(buf, uint8(a.dstPort>>8))
+ buf = append(buf, uint8(a.dstPort))
+
+ return buf
+}
+
+func (a *ppUnixTuple) toAddr(s string) net.Addr {
+ switch a.transport {
+ case trDatagram:
+ return &net.UnixAddr{
+ Name: s,
+ Net: "unixgram",
+ }
+ case trStream:
+ return &net.UnixAddr{
+ Name: s,
+ Net: "unix",
+ }
+ default:
+ return nil
+ }
+}
+
+func (a *ppUnixTuple) Src() net.Addr {
+ return a.toAddr(a.src)
+}
+
+func (a *ppUnixTuple) Dst() net.Addr {
+ return a.toAddr(a.dst)
+}
+
+func (a *ppUnixTuple) Encode() []byte {
+ buf := make([]byte, v2UnixSocketLength*2)
+ copy(buf, []byte(a.src))
+ copy(buf[v2UnixSocketLength:], []byte(a.dst))
+
+ return buf
+}
+
+func newPPFamily(af, tr uint8) ppFamily {
+ return ppFamily(af<<4 | tr)
+}
+
+func (f ppFamily) String() string {
+ var ret string
+ switch f.Transport() {
+ case trStream:
+ ret = "TCP"
+ case trDatagram:
+ ret = "UDP"
+ default:
+ return "UNKNOWN"
+ }
+
+ switch f.AddressFamily() {
+ case afInet:
+ ret += "4"
+ case afInet6:
+ ret += "6"
+ default:
+ return "UNKNOWN"
+ }
+ return ret
+}
+
+const (
+ v2HeaderLength = 16
+ v2UnixSocketLength = 108
+
+ version2 uint8 = 2
+
+ cmdLocal uint8 = 0
+ cmdProxy uint8 = 1
+
+ afUnspec uint8 = 0
+ afInet uint8 = 1
+ afInet6 uint8 = 2
+ afUnix uint8 = 3
+
+ trUnspec uint8 = 0
+ trStream uint8 = 1
+ trDatagram uint8 = 2
+)
+
+var proxyProtocolHeaderV1 = [6]byte{
+ 'P', 'R', 'O', 'X', 'Y', ' ',
+}
+var proxyProtocolMagicV2 = [12]byte{
+ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 'Q', 'U', 'I', 'T', 0x0a,
+}
+var ErrBadMagic = errors.New("bad magic")
+var ErrV1BadSyntax = errors.New("v1 header is incorrectly formatted")
+var ErrLength = errors.New("header is wrong length")
+var ErrUnrecognizedVersion = errors.New("unrecognized protocol version")
+var ErrUnrecognizedCommand = errors.New("unrecognized command")
+var ErrUnspecifiedAddressFamily = errors.New("unspecified address family")
+var ErrUnrecognizedAddressFamily = errors.New("unrecognized address family")
+var ErrUnrecognizedTransportProtocol = errors.New("unrecognized transport protocol")
+
+func (v ppVersionCommand) Version() uint8 {
+ return uint8(v) >> 4
+}
+
+func (v ppVersionCommand) Command() uint8 {
+ return uint8(v) & 0x0f
+}
+
+func (v ppFamily) AddressFamily() uint8 {
+ return uint8(v) >> 4
+}
+
+func (v ppFamily) Transport() uint8 {
+ return uint8(v) & 0x0f
+}
+
+func (v ppHeader) ReadAddress(r *bufio.Reader) (ppAddress, error) {
+ addressBuf := make([]byte, int(v.len))
+ n, err := r.Read(addressBuf)
+ if err != nil {
+ return nil, err
+ }
+ if n != int(v.len) {
+ return nil, fmt.Errorf("got wrong number of bytes reading address, expected %d got %d, buf: %+v", v.len, n, addressBuf) // ErrLength
+ }
+
+ var expectLength int
+ switch v.fa.AddressFamily() {
+ case afUnspec:
+ return nil, ErrUnspecifiedAddressFamily
+ case afInet:
+ expectLength = 12 // two IPv4 addresses + 2 ports
+ case afInet6:
+ expectLength = 36 // two IPv6 addresses + 2 ports
+ case afUnix:
+ expectLength = 2 * v2UnixSocketLength // two UNIX socket addresses
+ default:
+ panic("not reached")
+ }
+
+ if n != expectLength {
+ return nil, fmt.Errorf("proxy protocol message is wrong length for af %d, proto %d: expected %d, got %d, buf: %+v",
+ v.fa.AddressFamily(), v.fa.Transport(), expectLength, n, addressBuf)
+ }
+
+ switch v.fa.AddressFamily() {
+ case afInet:
+ return &ppAddressTuple{
+ transport: v.fa.Transport(),
+ srcAddr: addressBuf[0:4],
+ dstAddr: addressBuf[4:8],
+ srcPort: uint16(addressBuf[8])<<8 | uint16(addressBuf[9]),
+ dstPort: uint16(addressBuf[10])<<8 | uint16(addressBuf[11]),
+ }, nil
+ case afInet6:
+ return &ppAddressTuple{
+ transport: v.fa.Transport(),
+ srcAddr: addressBuf[0:16],
+ dstAddr: addressBuf[16:32],
+ srcPort: uint16(addressBuf[32])<<8 | uint16(addressBuf[33]),
+ dstPort: uint16(addressBuf[34])<<8 | uint16(addressBuf[35]),
+ }, nil
+ case afUnix:
+ return &ppUnixTuple{
+ src: strings.Trim(string(addressBuf[0:v2UnixSocketLength]), "\x00"),
+ dst: strings.Trim(string(addressBuf[v2UnixSocketLength:]), "\x00"),
+ }, nil
+ default:
+ panic("not reached")
+ }
+}
+
+func NewProxyProtocolV2Header(cmd, af, tr uint8) ppHeader {
+ return ppHeader{
+ vc: ppVersionCommand(version2<<4 | cmd),
+ fa: newPPFamily(af, tr),
+ }
+}
+
+func (v ppHeader) Encode(addr ppAddress) []byte {
+ out := make([]byte, v2HeaderLength)
+ copy(out, proxyProtocolMagicV2[:])
+
+ out[12] = uint8(v.vc)
+ out[13] = uint8(v.fa)
+ addrBytes := addr.Encode()
+ out[14] = uint8(len(addrBytes) >> 8)
+ out[15] = uint8(len(addrBytes))
+ out = append(out, addrBytes...)
+
+ return out
+}
+
+func parseProxyProtocolV2Header(header []byte) (*ppHeader, error) {
+ if len(header) != v2HeaderLength {
+ return nil, ErrLength
+ }
+
+ if !bytes.Equal(header[:len(proxyProtocolMagicV2)], proxyProtocolMagicV2[:]) {
+ return nil, ErrBadMagic
+ }
+
+ vc := ppVersionCommand(header[12])
+ fa := ppFamily(header[13])
+
+ if vc.Version() != version2 {
+ return nil, ErrUnrecognizedVersion
+ }
+
+ if vc.Command() != cmdLocal && vc.Command() != cmdProxy {
+ return nil, ErrUnrecognizedCommand
+ }
+
+ af := fa.AddressFamily()
+ if af != afUnspec && af != afInet && af != afInet6 && af != afUnix {
+ return nil, ErrUnrecognizedAddressFamily
+ }
+
+ pr := fa.Transport()
+ if pr != trUnspec && pr != trStream && pr != trDatagram {
+ return nil, ErrUnrecognizedTransportProtocol
+ }
+
+ return &ppHeader{
+ vc: vc,
+ fa: fa,
+ // length is a uint16 in network (big endian) order
+ len: uint16(header[14])<<8 | uint16(header[15]),
+ }, nil
+}
+
+// PROXY protocol v2 header:
+// offset len type desc
+// 0 12 []byte magic
+// 12 1 uint8 version
+// 13 1 uint8 family
+// 14 2 uint16 length
+// [union] tcp/udp over IPv4 (overall 12 bytes)
+// 16 4 uint32 src addr
+// 20 4 uint32 dst addr
+// 24 2 uint16 src port
+// 26 2 uint16 dst port
+// [union] tcp/udp over IPv6 (overall 36 bytes)
+// 16 16 []uint8 src addr
+// 32 16 []uint8 dst addr
+// 48 2 uint16 src port
+// 50 2 uint16 dst port
+// [union] unix socket (overall 216 bytes)
+// 16 108 []uint8 src addr
+// 124 108 []uint8 dst addr
+
+func ListenProxyProtocol(network, addr string) (net.Listener, error) {
+ l, err := net.Listen(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ return &proxyProtocolListener{l}, nil
+}
+
+func (p *proxyProtocolListener) Accept() (net.Conn, error) {
+ conn, err := p.Listener.Accept()
+ if err != nil {
+ return nil, err
+ }
+
+ return newProxyProtocolConn(conn)
+}
+
+func newProxyProtocolConn(conn net.Conn) (net.Conn, error) {
+ var acceptErr error
+ reader := bufio.NewReader(conn)
+ header, err := reader.Peek(len(proxyProtocolMagicV2))
+ if err != nil {
+ acceptErr = err
+ goto nonProxyAddr
+ }
+
+ if bytes.Equal(header[:len(proxyProtocolMagicV2)], proxyProtocolMagicV2[:]) {
+ hdr := make([]byte, v2HeaderLength)
+ n, err := reader.Read(hdr)
+ if err != nil || n != v2HeaderLength {
+ acceptErr = err
+ goto nonProxyAddr
+ }
+ parsedHeader, err := parseProxyProtocolV2Header(hdr)
+ if err != nil {
+ acceptErr = err
+ goto nonProxyAddr
+ }
+ addr, err := parsedHeader.ReadAddress(reader)
+ if err != nil {
+ acceptErr = err
+ goto nonProxyAddr
+ }
+ return &proxyProtocolConn{
+ Conn: conn,
+ reader: reader,
+ addrLocal: addr.Dst(),
+ addrProxy: conn.RemoteAddr(),
+ addrOrigin: addr.Src(),
+ }, nil
+ } else if bytes.Equal(header[:len(proxyProtocolHeaderV1)], proxyProtocolHeaderV1[:]) {
+ buf, err := reader.ReadBytes('\n')
+ if err != nil {
+ acceptErr = ErrV1BadSyntax
+ goto nonProxyAddr
+ }
+ hdr := strings.Split(strings.Trim(string(buf), "\r\n"), " ")
+ if len(hdr) != 6 {
+ acceptErr = ErrLength
+ goto nonProxyAddr
+ }
+
+ var ppAddr ppAddress
+ switch hdr[1] {
+ case newPPFamily(afInet, trStream).String(), newPPFamily(afInet6, trStream).String():
+ srcPort, err := strconv.Atoi(hdr[4])
+ if err != nil {
+ acceptErr = ErrV1BadSyntax
+ goto nonProxyAddr
+ }
+ dstPort, err := strconv.Atoi(hdr[5])
+ if err != nil {
+ acceptErr = ErrV1BadSyntax
+ goto nonProxyAddr
+ }
+ ppAddr = &ppAddressTuple{
+ transport: trStream,
+ srcAddr: net.ParseIP(hdr[2]),
+ dstAddr: net.ParseIP(hdr[3]),
+ srcPort: uint16(srcPort),
+ dstPort: uint16(dstPort),
+ }
+ default:
+ acceptErr = ErrUnrecognizedTransportProtocol
+ goto nonProxyAddr
+ }
+
+ return &proxyProtocolConn{
+ Conn: conn,
+ reader: reader,
+ addrLocal: ppAddr.Dst(),
+ addrProxy: conn.RemoteAddr(),
+ addrOrigin: ppAddr.Src(),
+ }, nil
+ } else {
+ acceptErr = nil
+ goto nonProxyAddr
+ }
+
+nonProxyAddr:
+ if acceptErr != nil {
+ log.Default().WithPrefix("http.ProxyProtocol").Warningf(
+ "error accepting conn %s->%s with proxy protocol: %v",
+ conn.RemoteAddr(), conn.LocalAddr(), acceptErr,
+ )
+
+ conn.Write([]byte(
+ fmt.Sprintf(
+ "unable to accept connection, proxy protocol error encountered: %v",
+ acceptErr,
+ ),
+ ))
+
+ conn.Close()
+ return nil, acceptErr
+ }
+ return &proxyProtocolConn{
+ Conn: conn,
+ reader: reader,
+ addrLocal: conn.LocalAddr(),
+ addrProxy: nil,
+ addrOrigin: conn.RemoteAddr(),
+ }, nil
+}
+
+func (c *proxyProtocolConn) Read(buf []byte) (int, error) {
+ return c.reader.Read(buf)
+}
+
+func (c *proxyProtocolConn) LocalAddr() net.Addr {
+ return c.addrLocal
+}
+
+func (c *proxyProtocolConn) ProxyAddr() net.Addr {
+ return c.addrProxy
+}
+
+func (c *proxyProtocolConn) RemoteAddr() net.Addr {
+ return c.addrOrigin
+}
--- /dev/null
+package http
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+var (
+ ipv6TestAddr1 = net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x44}
+ ipv6TestAddr2 = net.IP{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88}
+)
+
+func mergeByteSeqs(seqs ...[]byte) []byte {
+ var out []byte
+ for _, seq := range seqs {
+ out = append(out, seq...)
+ }
+ return out
+}
+
+func padSocket(in string) []byte {
+ out := make([]byte, v2UnixSocketLength)
+ copy(out, []byte(in))
+ return out
+}
+
+func TestProxyProtocolEncode(t *testing.T) {
+ type testCase struct {
+ cmd, af, tr uint8
+ addr ppAddress
+ expectBytes []byte
+ }
+
+ testCases := []*testCase{
+ {
+ cmdProxy, afInet, trStream,
+ &ppAddressTuple{trStream, net.IP{10, 0, 0, 1}, net.IP{10, 0, 0, 2}, 1024, 1025},
+ []byte{
+ 0x21, 0x11,
+ 0, 12,
+ 0x0a, 0x00, 0x00, 0x01,
+ 0x0a, 0x00, 0x00, 0x02,
+ 0x04, 0x00,
+ 0x04, 0x01,
+ },
+ },
+ {
+ cmdProxy, afInet6, trStream,
+ &ppAddressTuple{trStream, ipv6TestAddr1, ipv6TestAddr2, 1024, 1025},
+ mergeByteSeqs(
+ []byte{0x21, 0x21, 0, 36},
+ ipv6TestAddr1,
+ ipv6TestAddr2,
+ []byte{0x04, 0x00},
+ []byte{0x04, 0x01}),
+ },
+ {
+ cmdProxy, afUnix, trStream,
+ &ppUnixTuple{trStream, "/run/test/one.sock", "/run/test/two.sock"},
+ mergeByteSeqs(
+ []byte{0x21, 0x31, 0, 216},
+ padSocket("/run/test/one.sock"),
+ padSocket("/run/test/two.sock")),
+ },
+ }
+
+ for _, tc := range testCases {
+ header := NewProxyProtocolV2Header(tc.cmd, tc.af, tc.tr)
+ out := header.Encode(tc.addr)
+ assert.Equal(t, out[len(proxyProtocolMagicV2):], tc.expectBytes)
+ }
+}
+
+func TestProxyProtocolDecode(t *testing.T) {
+ type testCase struct {
+ ppHeader []byte
+ expectAddr string
+ }
+
+ listener, err := ListenProxyProtocol("tcp", ":0")
+ assert.NoError(t, err)
+ defer listener.Close()
+
+ testCases := []*testCase{
+ {
+ NewProxyProtocolV2Header(cmdProxy, afInet, trStream).Encode(
+ &ppAddressTuple{trStream, net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 2}, 1024, 1025}),
+ "127.0.0.1:1024->[%s]->127.0.0.2:1025",
+ },
+ {
+ []byte("PROXY TCP4 127.0.0.1 127.0.0.2 1024 1025\r\n"),
+ "127.0.0.1:1024->[%s]->127.0.0.2:1025",
+ },
+ {
+ []byte("PROXY TCP6 fe80::1 fe80::2 1024 1025\r\n"),
+ "[fe80::1]:1024->[%s]->[fe80::2]:1025",
+ },
+ }
+
+ handle := func(conn net.Conn) {
+ defer conn.Close()
+ reader := bufio.NewReader(conn)
+ pConn := conn.(ProxyProtocolConn)
+
+ for {
+ lineBytes, _, err := reader.ReadLine()
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ assert.NoError(t, err)
+
+ line := string(lineBytes)
+ switch strings.Trim(line, "\r\n") {
+ case "whoami":
+ conn.Write([]byte(
+ fmt.Sprintf("%s->[%+v]->%s\n", conn.RemoteAddr(), pConn.ProxyAddr(), conn.LocalAddr()),
+ ))
+ case "echo":
+ conn.Write([]byte(line + "\n"))
+ case "bye":
+ conn.Write([]byte("tallyho\n"))
+ return
+ }
+ }
+ }
+
+ go (func() {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+
+ assert.NoError(t, err)
+ return
+ }
+
+ go handle(conn)
+ }
+ })()
+
+ for _, tc := range testCases {
+ conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String())
+ assert.NoError(t, err)
+ reader := bufio.NewReader(conn)
+ fmt.Printf("%+v\n", tc.ppHeader)
+ _, err = conn.Write(tc.ppHeader)
+ assert.NoError(t, err)
+
+ expectAddr := fmt.Sprintf(tc.expectAddr, conn.LocalAddr().String())
+
+ conn.Write([]byte("whoami\n"))
+ answer, _, err := reader.ReadLine()
+ assert.NoError(t, err)
+ assert.Equal(t, expectAddr, strings.Trim(string(answer), "\r\n"))
+ conn.Write([]byte("bye\n"))
+ answer, _, err = reader.ReadLine()
+ assert.NoError(t, err)
+ assert.Equal(t, "tallyho", strings.Trim(string(answer), "\r\n"))
+ conn.Close()
+ }
+}
import (
"context"
+ "crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
+ "sync"
"go.fuhry.dev/runtime/mtls"
"go.fuhry.dev/runtime/utils/log"
}
type Listener struct {
- Addr string `yaml:"listen"`
- InsecureAddr string `yaml:"listen_insecure"`
- Certificate string `yaml:"cert"`
- VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"`
+ Addr string `yaml:"listen"`
+ ProxyProtocol bool `yaml:"proxy_protocol"`
+ InsecureAddr string `yaml:"listen_insecure"`
+ Certificate string `yaml:"cert"`
+ VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"`
}
type Server struct {
Listener *Listener `yaml:"listener"`
Context context.Context `yaml:"-"`
+
+ tlsServer *http.Server
+ httpServer *http.Server
}
type initHook func(context.Context, *yaml.Node) (context.Context, error)
return nil
}
-func (s *Server) Create() (*http.Server, error) {
+func (s *Server) ListenAndServeTLS() error {
listenerCtx := context.WithValue(s.Context, kListener, s.Listener)
- return s.Listener.NewHTTPServerWithContext(listenerCtx)
+ server, err := s.Listener.NewHTTPServerWithContext(listenerCtx)
+ if err != nil {
+ return err
+ }
+
+ var listener net.Listener
+ if s.Listener.ProxyProtocol {
+ listener, err = ListenProxyProtocol("tcp", server.Addr)
+ LoggerFromContext(listenerCtx).Noticef(
+ "Listening for standard or PROXY protocol TLS connections on %s", listener.Addr(),
+ )
+ } else {
+ listener, err = net.Listen("tcp", server.Addr)
+ LoggerFromContext(listenerCtx).Noticef(
+ "Listening for TLS connnections on %s", listener.Addr(),
+ )
+ }
+ if err != nil {
+ return err
+ }
+ if server.TLSConfig != nil {
+ listener = tls.NewListener(listener, server.TLSConfig)
+ }
+
+ s.tlsServer = server
+ return server.Serve(listener)
}
-func (s *Server) CreateInsecure() *http.Server {
+func (s *Server) ListenAndServe() error {
+ var listener net.Listener
+ var err error
+
listenerCtx := context.WithValue(s.Context, kListener, s.Listener)
- return s.Listener.NewHTTPSRedirectorWithContext(listenerCtx)
+ server := s.Listener.NewHTTPSRedirectorWithContext(listenerCtx)
+ if s.Listener.ProxyProtocol {
+ listener, err = ListenProxyProtocol("tcp", server.Addr)
+ LoggerFromContext(listenerCtx).Noticef(
+ "Listening for standard or PROXY protocol connections on %s", listener.Addr(),
+ )
+ } else {
+ listener, err = net.Listen("tcp", server.Addr)
+ LoggerFromContext(listenerCtx).Noticef(
+ "Listening on %s", listener.Addr(),
+ )
+ }
+ if err != nil {
+ return err
+ }
+
+ s.httpServer = server
+ return server.Serve(listener)
+}
+
+func (s *Server) Shutdown(shutdownCtx context.Context) error {
+ var wg sync.WaitGroup
+ var err error
+ if s.httpServer != nil {
+ wg.Add(1)
+
+ go (func() {
+ defer wg.Done()
+ err = s.httpServer.Shutdown(shutdownCtx)
+ })()
+ }
+ if s.tlsServer != nil {
+ wg.Add(1)
+ go (func() {
+ defer wg.Done()
+ err = s.tlsServer.Shutdown(shutdownCtx)
+ })()
+ }
+
+ wg.Wait()
+ return err
}
// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host
lm := log.NewLoggingMiddlewareWithLogger(
http.HandlerFunc(l.handle),
- logger.AppendPrefix("access"))
+ logger.AppendPrefix(".access"))
server := &http.Server{
Addr: l.Addr,