From: Dan Fuhry Date: Thu, 6 Nov 2025 11:49:53 +0000 (-0500) Subject: [http] live reload, SNI proxying, bugfixes X-Git-Url: https://go.fuhry.dev/?a=commitdiff_plain;h=33d9311cc439bb8f500ff44c9a194920706d0504;p=runtime.git [http] live reload, SNI proxying, bugfixes - Initial backend live reload support, to be enabled in binary with merge of ephs/config_watcher - Fix SNI listener closure upon single connection TLS handshake error --- diff --git a/http/http.go b/http/http.go index 1dfb1f3..b65b019 100644 --- a/http/http.go +++ b/http/http.go @@ -2,6 +2,7 @@ package http import ( "context" + "crypto/tls" "errors" "fmt" "net" @@ -18,7 +19,7 @@ type HTTPVirtualHost struct { *BaseVirtualHost `yaml:",inline"` Routes []*Route `yaml:"routes"` - Certificate string `yaml:"certificate"` + Certificate string `yaml:"cert"` TrustUpstreamRequestID *bool `yaml:"trust_upstream_request_id"` tlsServer *http.Server @@ -77,12 +78,13 @@ func (v *HTTPVirtualHost) NewHTTPServerWithContext(ctx context.Context) (*http.S return server, nil } -func (v *HTTPVirtualHost) Shutdown(ctx context.Context) error { +func (v *HTTPVirtualHost) Shutdown(ctx context.Context) (err error) { if v.tlsServer != nil { - return v.tlsServer.Shutdown(ctx) + err = v.tlsServer.Shutdown(ctx) + v.tlsServer = nil } - return nil + return } func (v *HTTPVirtualHost) handle(w http.ResponseWriter, r *http.Request) { @@ -109,6 +111,10 @@ func (v *HTTPVirtualHost) handle(w http.ResponseWriter, r *http.Request) { v.fulfill(w, r, v.Routes) } +func (v *HTTPVirtualHost) swapTlsConfig(c *tls.Config) { + v.tlsServer.TLSConfig = c +} + func vhostGlobToRegexp(name string) (*regexp.Regexp, error) { var parts []string for _, part := range strings.Split(name, ".") { diff --git a/http/route_action_s3.go b/http/route_action_s3.go index 4e5b6ee..40dc3a1 100644 --- a/http/route_action_s3.go +++ b/http/route_action_s3.go @@ -71,7 +71,7 @@ func (a *S3Action) Handle(w http.ResponseWriter, r *http.Request, next http.Hand stat, err := object.Stat() if err != nil { er := minio.ToErrorResponse(err) - w.WriteHeader(er.StatusCode) + w.WriteHeader(coalesce(er.StatusCode, http.StatusInternalServerError)) w.Write([]byte(fmt.Sprintf("failed to stat object %q: %+v", objPath, err))) return } @@ -188,6 +188,16 @@ func s3ActionFromRouteYaml(node *yaml.Node) (RouteAction, error) { return nil, nil } +func coalesce[T comparable](items ...T) T { + var empty T + for _, last := range items { + if last != empty { + return last + } + } + return empty +} + func init() { AddRouteParseFunc(s3ActionFromRouteYaml) } diff --git a/http/server.go b/http/server.go index 1194124..18ae75f 100644 --- a/http/server.go +++ b/http/server.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "go.fuhry.dev/runtime/mtls" "go.fuhry.dev/runtime/utils/hashset" "go.fuhry.dev/runtime/utils/log" "go.fuhry.dev/runtime/utils/stringmatch" @@ -34,6 +35,7 @@ type Route struct { type NameMatcher interface { MatchesName(string) bool + Names() []string } type VirtualHost interface { @@ -56,21 +58,24 @@ type Listener struct { // ExternalPort is an optional port to use when sending http-to-https redirects. If external traffic // enters through a different port than what the server is bound to locally, this setting ensures that // redirects reference the correct port. - ExternalPort int `yaml:"external_port"` - ProxyProtocol bool `yaml:"proxy_protocol"` - InsecureAddr string `yaml:"listen_insecure"` - Certificate string `yaml:"cert"` - TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"` - VirtualHosts []VirtualHost `yaml:"-"` - HTTPVirtualHosts map[string]*HTTPVirtualHost `yaml:"virtual_hosts"` + ExternalPort int `yaml:"external_port"` + ProxyProtocol bool `yaml:"proxy_protocol"` + InsecureAddr string `yaml:"listen_insecure"` + Certificate string `yaml:"cert"` + TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"` + DrainTime time.Duration `yaml:"drain_time"` + VirtualHosts []VirtualHost `yaml:"-"` sniServer *SNIListener httpServer *http.Server + wg sync.WaitGroup } type Server struct { Listeners []*Listener `yaml:"listeners"` Context context.Context `yaml:"-"` + + servingTLS, servingHTTP bool } type Authorization struct { @@ -99,6 +104,10 @@ const ( hostnameWildcardMulti = `[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*` ) +const ( + DefaultDrainTime = 30 * time.Second +) + var ErrServerClosed = http.ErrServerClosed var randSrc = rand.New(rand.NewSource(time.Now().UnixNano())) var portSpec = regexp.MustCompile(":[0-9]{1,5}$") @@ -160,6 +169,100 @@ func (r *Route) UnmarshalYAML(node *yaml.Node) error { return nil } +func (l *Listener) UnmarshalYAML(node *yaml.Node) error { + var rawNode struct { + Addr string `yaml:"listen"` + ExternalPort int `yaml:"external_port"` + ProxyProtocol bool `yaml:"proxy_protocol"` + InsecureAddr string `yaml:"listen_insecure"` + Certificate string `yaml:"cert"` + TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"` + DrainTime string `yaml:"drain_time"` + } + + var httpNode struct { + HTTPVirtualHosts map[string]*HTTPVirtualHost `yaml:"virtual_hosts"` + } + + var tcpNode struct { + SNIVirtualHosts map[string]*SNIVirtualHost `yaml:"virtual_hosts"` + } + + if err := node.Decode(&rawNode); err != nil { + return err + } + + if err := node.Decode(&httpNode); err != nil { + return err + } + + if err := node.Decode(&tcpNode); err != nil { + return err + } + + l.Addr = rawNode.Addr + l.ExternalPort = rawNode.ExternalPort + l.ProxyProtocol = rawNode.ProxyProtocol + l.InsecureAddr = rawNode.InsecureAddr + l.Certificate = rawNode.Certificate + l.TrustUpstreamRequestID = rawNode.TrustUpstreamRequestID + + if rawNode.DrainTime != "" { + if dur, err := time.ParseDuration(rawNode.DrainTime); err == nil { + l.DrainTime = dur + } else { + return fmt.Errorf("failed to parse drain_time: %v", err) + } + } else { + l.DrainTime = DefaultDrainTime + } + + coveredVhosts := hashset.NewHashSet[string]() + + if httpNode.HTTPVirtualHosts != nil { + for name, vhost := range httpNode.HTTPVirtualHosts { + if vhost.Routes == nil { + continue + } + if vhost.BaseVirtualHost == nil { + vhost.BaseVirtualHost = &BaseVirtualHost{} + } + coveredVhosts.Add(name) + vhost.Aliases = append(vhost.Aliases, name) + l.VirtualHosts = append(l.VirtualHosts, vhost) + } + } + if tcpNode.SNIVirtualHosts != nil { + for name, vhost := range tcpNode.SNIVirtualHosts { + if vhost.Backend == nil { + continue + } + + if vhost.BaseVirtualHost == nil { + vhost.BaseVirtualHost = &BaseVirtualHost{} + } + + if coveredVhosts.Contains(name) { + return fmt.Errorf( + "virtual host %q contains both http routes and a SNI backend, this is not allowed", + name) + } + + coveredVhosts.Add(name) + vhost.Aliases = append(vhost.Aliases, name) + l.VirtualHosts = append(l.VirtualHosts, vhost) + } + } + + if coveredVhosts.Len() < len(httpNode.HTTPVirtualHosts) { + return fmt.Errorf( + "%d virtual hosts are unconfigured, specify either routes for HTTP or tcp_backend "+ + "for SNI proxying", len(httpNode.HTTPVirtualHosts)-coveredVhosts.Len()) + } + + return nil +} + // UnmarshalYAML implements yaml.Unmarshaler func (s *Server) UnmarshalYAML(node *yaml.Node) error { var lc struct { @@ -183,17 +286,7 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error { lc.Listeners = []*Listener{lc.Listener} } - s.Listeners = lc.Listeners - - for _, listener := range s.Listeners { - for name, vhost := range listener.HTTPVirtualHosts { - if vhost.BaseVirtualHost == nil { - vhost.BaseVirtualHost = &BaseVirtualHost{} - } - vhost.Aliases = append(vhost.Aliases, name) - listener.VirtualHosts = append(listener.VirtualHosts, vhost) - } - + for _, listener := range lc.Listeners { for i, vhost := range listener.VirtualHosts { switch vhost := vhost.(type) { case *HTTPVirtualHost: @@ -204,14 +297,163 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error { vhost.TrustUpstreamRequestID = &listener.TrustUpstreamRequestID } if err := vhost.precompileAliases(); err != nil { - return fmt.Errorf("bootstrapping vhost %d: %v", i, err) + return fmt.Errorf("bootstrapping http vhost %d: %v", i, err) + } + log.Default().Debugf("vhost %+v cert: %s", vhost.Names(), vhost.Certificate) + case *SNIVirtualHost: + if err := vhost.precompileAliases(); err != nil { + return fmt.Errorf("bootstrapping sni vhost %d: %v", i, err) } + log.Default().Debugf("sni vhost %+v bootatrapped", vhost.Names()) default: return fmt.Errorf("unsupported vhost type: %T", vhost) } } } + if len(s.Listeners) > 0 && (s.servingHTTP || s.servingTLS) { + LoggerFromContext(s.Context).Noticef("attempting hot reload on %d listeners", len(lc.Listeners)) + if len(s.Listeners) != len(lc.Listeners) { + return fmt.Errorf("listener count mismatch (%d != %d), refusing hot reload", len(s.Listeners), len(lc.Listeners)) + } + + for i := range s.Listeners { + listenerCtx := context.WithValue(s.Context, kListener, s.Listeners[i]) + + oldL := s.Listeners[i] + newL := lc.Listeners[i] + + if oldL.Addr != newL.Addr { + return fmt.Errorf("hot reload refused: listener %d: modifying address/port is not supported", i) + } + + if oldL.InsecureAddr != newL.InsecureAddr { + return fmt.Errorf("hot reload refused: listener %d: modifying insecure address/port is not supported", i) + } + + if oldL.ProxyProtocol != newL.ProxyProtocol { + return fmt.Errorf("hot reload refused: listener %d: toggling proxy protocol is not supported", i) + } + + oldL.TrustUpstreamRequestID = newL.TrustUpstreamRequestID + oldL.ExternalPort = newL.ExternalPort + oldL.DrainTime = newL.DrainTime + oldL.Certificate = newL.Certificate + + // virtual hosts that need to be started or closed + var ( + // virtual hosts to close are tracked by index, so we can remove them from the + // l.VirtualHosts slice. + closeVhosts []int + // keep track of which vhosts corresponded to those that were already serving, so we + // don't try to start them again + matchedVhosts = hashset.NewHashSet[VirtualHost]() + ) + for vi, vh := range oldL.VirtualHosts { + var match VirtualHost + for _, nvh := range newL.VirtualHosts { + if virtualHostsEqual(vh, nvh) { + match = nvh + matchedVhosts.Add(nvh) + break + } + } + + if match == nil { + // none of the vhosts in the new config match the current one, which means this + // vhost is going away + LoggerFromContext(s.Context).Debugf( + "listener %d: vhost(%T, %+v) is going away", + i, vh, vh.Names(), + ) + closeVhosts = append(closeVhosts, vi) + continue + } + + switch oldVh := vh.(type) { + case *HTTPVirtualHost: + newVh := match.(*HTTPVirtualHost) + + oldVh.Routes = newVh.Routes + oldVh.TrustUpstreamRequestID = newVh.TrustUpstreamRequestID + if oldVh.Certificate != newVh.Certificate { + newCert := mtls.NewSSLCertificate(newVh.Certificate) + if tlsConfig, err := newCert.TlsConfig(listenerCtx); err == nil { + oldVh.swapTlsConfig(tlsConfig) + } else { + LoggerFromContext(listenerCtx).Errorf( + "vhost %+v: error loading new ssl certificate %q: %v", + oldVh.Names(), newVh.Certificate, err) + } + } + LoggerFromContext(listenerCtx).Debugf( + "loaded %d new routes for HTTP vhost %+v", + len(oldVh.Routes), oldVh.Names()) + case *SNIVirtualHost: + newVh := match.(*SNIVirtualHost) + oldVh.Backend = newVh.Backend + } + } + + // start newly added virtual hosts. + // do this before old vhosts are shut down, because if there are no overlaps we need to + // increment the listener's wait group before it's decremented by the old vhosts + // shutting down. + for _, vhost := range newL.VirtualHosts { + if matchedVhosts.Contains(vhost) { + // already serving + continue + } + + // this is a new vhost, start serving. + serveFunc, err := vhost.Serve(listenerCtx, oldL.sniServer) + if err != nil { + LoggerFromContext(listenerCtx).Warningf( + "failed to start vhost(%T, %+v): %v", + vhost, vhost.Names(), err) + continue + } + + oldL.wg.Add(1) + go (func() { + defer oldL.wg.Done() + serveFunc() + })() + + oldL.VirtualHosts = append(oldL.VirtualHosts, vhost) + + LoggerFromContext(listenerCtx).Debugf("started new vhost %+v", vhost.Names()) + } + + if len(closeVhosts) > 0 { + for vi := len(closeVhosts) - 1; vi >= 0; vi -= 1 { + idx := closeVhosts[vi] + shutdownCtx, cancel := context.WithTimeout(listenerCtx, oldL.DrainTime) + defer cancel() + + vhost := oldL.VirtualHosts[idx] + + if err := vhost.Shutdown(shutdownCtx); err != nil { + LoggerFromContext(listenerCtx).Warningf( + "failed to shutdown vhost(%T, %+v): %v", + vhost, vhost.Names(), err) + } + + LoggerFromContext(listenerCtx).Debugf("shut down removed vhost %+v", vhost.Names()) + oldL.VirtualHosts = slices.Concat( + oldL.VirtualHosts[:idx], + oldL.VirtualHosts[idx+1:], + ) + } + } + } + + LoggerFromContext(s.Context).Noticef("hot reload complete") + return nil + } + + s.Listeners = lc.Listeners + for _, initHook := range initHooks { newCtx, err := initHook(s.Context, node) if err != nil { @@ -224,6 +466,8 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error { } func (s *Server) ListenAndServeTLS() (err error) { + s.servingTLS = true + defer (func() { s.servingTLS = false })() wg := &sync.WaitGroup{} for _, listener := range s.Listeners { @@ -241,6 +485,9 @@ func (s *Server) ListenAndServeTLS() (err error) { } func (s *Server) ListenAndServe() (err error) { + s.servingHTTP = true + defer (func() { s.servingHTTP = false })() + wg := &sync.WaitGroup{} for _, listener := range s.Listeners { var netListener net.Listener @@ -343,23 +590,22 @@ func (l *Listener) ListenAndServeSNI(ctx context.Context) (err error) { l.sniServer = NewSNIListener(netListener) - wg := &sync.WaitGroup{} for _, vhost := range l.VirtualHosts { serveFunc, err := vhost.Serve(ctx, l.sniServer) if err != nil { return err } - wg.Add(1) + l.wg.Add(1) go (func() { - defer wg.Done() + defer l.wg.Done() serveFunc() })() } - wg.Add(1) + l.wg.Add(1) go (func() { - defer wg.Done() + defer l.wg.Done() err = l.sniServer.Serve() })() @@ -465,6 +711,32 @@ func (v *BaseVirtualHost) MatchesName(hostname string) bool { return false } +func (v *BaseVirtualHost) Names() (out []string) { + for _, name := range v.aliasesExact.AsSlice() { + out = append(out, name) + } + for _, re := range v.aliasesCompiled { + out = append(out, re.String()) + } + + return +} + +func virtualHostsEqual(a, b VirtualHost) bool { + if fmt.Sprintf("%T", a) != fmt.Sprintf("%T", b) { + return false + } + + aNames := hashset.FromSlice(a.Names()) + bNames := hashset.FromSlice(b.Names()) + + if !aNames.Equal(bNames) { + return false + } + + return true +} + func LoggerFromContext(ctx context.Context) log.Logger { l := ctx.Value(kLogger) if logger, ok := l.(log.Logger); ok { diff --git a/http/sni_listener.go b/http/sni_listener.go index a7d501b..df33671 100644 --- a/http/sni_listener.go +++ b/http/sni_listener.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "sync" "time" "go.fuhry.dev/runtime/utils/log" @@ -17,11 +18,27 @@ type sniConn struct { } type sniVirtualHost struct { - tlsConfig *tls.Config - sniListener *SNIListener - matcher NameMatcher - connChan chan sniConn - logger log.Logger + tlsConfig *tls.Config + sniListener *SNIListener + matcher NameMatcher + connChan chan sniConn + logger log.Logger + shutdownOnce sync.Once + accepting bool +} + +type temporaryError struct { + error +} + +var _ net.Error = &temporaryError{} + +func (e *temporaryError) Temporary() bool { + return true +} + +func (e *temporaryError) Timeout() bool { + return false } // SNIListener is a multiplexing listener that routes connections based on the @@ -31,7 +48,8 @@ type sniVirtualHost struct { // created with net.ListenTCP). type SNIListener struct { listener net.Listener - vhosts []*sniVirtualHost + vhosts map[*sniVirtualHost]struct{} + vhostsMu sync.Mutex logger log.Logger } @@ -40,6 +58,7 @@ func NewSNIListener(l net.Listener) *SNIListener { listener := &SNIListener{ listener: l, logger: logger, + vhosts: make(map[*sniVirtualHost]struct{}), } return listener @@ -49,14 +68,39 @@ func (l *SNIListener) Serve() error { for { conn, err := l.listener.Accept() if err != nil { - for _, v := range l.vhosts { - v.connChan <- sniConn{nil, err} + if netErr, ok := err.(net.Error); ok { + if netErr.Temporary() { + l.logger.Noticef("temporary error calling accept: %v", netErr) + continue + } + if netErr.Timeout() { + l.logger.Noticef("timeout error calling accept: %v", netErr) + continue + } + + l.logger.Errorf( + "permanent error calling Accept(), propagating to vhosts: %v", + netErr) + } else { + l.logger.Errorf( + "non-net.Error of type %T calling Accept(), propagating to vhosts: %v", + err, err) + } + l.vhostsMu.Lock() + defer l.vhostsMu.Unlock() + + for v := range l.vhosts { + v.msg(nil, err) } return err } - l.logger.V(2).Infof("new conn: %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("[%s]", conn.RemoteAddr()))) + l.logger.V(1).Infof("new conn: %s <-> %s, %d vhosts to check", + conn.RemoteAddr(), + conn.LocalAddr(), + len(l.vhosts)) + + go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("<->%s", conn.RemoteAddr()))) } } @@ -87,14 +131,20 @@ func (l *SNIListener) handle(conn net.Conn, logger log.Logger) { writer: conn, } - for _, vhost := range l.vhosts { + l.vhostsMu.Lock() + defer l.vhostsMu.Unlock() + for vhost := range l.vhosts { if vhost.matcher.MatchesName(clientHello.ServerName) { - vhost.connChan <- sniConn{childConn, nil} + logger.V(1).Infof( + "new conn matched vhost %+v for name %q, dispatch", + vhost.matcher.Names(), + clientHello.ServerName) + vhost.msg(childConn, nil) return } } - logger.Warningf("no servername match: %q", clientHello.ServerName) + l.logger.Warningf("no servername match: %q", clientHello.ServerName) conn.Close() } @@ -107,7 +157,9 @@ func (l *SNIListener) AddVirtualHost(matcher NameMatcher, tlsConfig *tls.Config) logger: l.logger, } - l.vhosts = append(l.vhosts, vhost) + l.vhostsMu.Lock() + defer l.vhostsMu.Unlock() + l.vhosts[vhost] = struct{}{} return vhost, nil } @@ -126,8 +178,22 @@ func (l *SNIListener) Close() error { return l.listener.Close() } +func (l *SNIListener) deregister(v *sniVirtualHost) { + v.logger.V(1).Debugf("deregistering SNI vhost %+v from listener %v", v.matcher.Names(), l.Addr()) + + l.vhostsMu.Lock() + defer l.vhostsMu.Unlock() + + delete(l.vhosts, v) +} + func (v *sniVirtualHost) Accept() (net.Conn, error) { c := <-v.connChan + + if c.conn == nil && c.err == nil { + return nil, net.ErrClosed + } + if c.conn != nil { v.logger.V(3).Noticef("recv'd dispatched conn: %s <-> %s", c.conn.RemoteAddr(), c.conn.LocalAddr()) } else if c.err != nil { @@ -139,14 +205,7 @@ func (v *sniVirtualHost) Accept() (net.Conn, error) { return c.conn, nil } - tlsConn := tls.Server(c.conn, v.tlsConfig) - - if err := tlsConn.Handshake(); err != nil { - c.conn.Close() - return nil, err - } - - return tlsConn, nil + return tls.Server(c.conn, v.tlsConfig), nil } func (v *sniVirtualHost) Addr() net.Addr { @@ -154,7 +213,28 @@ func (v *sniVirtualHost) Addr() net.Addr { } func (v *sniVirtualHost) Close() error { - return v.sniListener.Close() + v.shutdownOnce.Do(func() { + v.sniListener.deregister(v) + + if v.connChan != nil { + v.msg(nil, net.ErrClosed) + close(v.connChan) + v.connChan = nil + } + }) + + return nil +} + +func (v *sniVirtualHost) msg(conn net.Conn, err error) { + if v.connChan != nil { + select { + case v.connChan <- sniConn{conn, err}: + // yay message sent + default: + // oops + } + } } // ref: https://www.agwa.name/blog/post/writing_an_sni_proxy_in_go diff --git a/http/tcp.go b/http/tcp.go new file mode 100644 index 0000000..2a55de5 --- /dev/null +++ b/http/tcp.go @@ -0,0 +1,169 @@ +package http + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" + "time" + + rnet "go.fuhry.dev/runtime/net" + "go.fuhry.dev/runtime/net/dns" + "go.fuhry.dev/runtime/utils/log" +) + +type SNIVirtualHost struct { + *BaseVirtualHost `yaml:",inline"` + + Backend *struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + } `yaml:"tcp_backend"` + + ctx context.Context + logger log.Logger + listener net.Listener +} + +func (v *SNIVirtualHost) Serve(ctx context.Context, sniServer *SNIListener) (func() error, error) { + if v.Backend.Port < 1 || v.Backend.Port > 65535 { + return nil, errors.New("backend port must be 1-65535") + } + if len(v.Aliases) < 1 { + return nil, errors.New("no aliases specified for TCP vhost") + } + + v.logger = LoggerFromContext(ctx).AppendPrefix(fmt.Sprintf(".SNI<%s>", v.Aliases[0])) + + sniListener, err := sniServer.AddVirtualHost(v, nil) + if err != nil { + return nil, err + } + + v.ctx = ctx + v.listener = sniListener + + return v.serve, nil +} + +func (v *SNIVirtualHost) Shutdown(ctx context.Context) (err error) { + if v.listener != nil { + err = v.listener.Close() + v.listener = nil + } + return +} + +func (v *SNIVirtualHost) serve() error { + defer v.listener.Close() + + for { + conn, err := v.listener.Accept() + if err != nil { + return err + } + + go v.conn(conn) + } +} + +func (v *SNIVirtualHost) conn(downstream net.Conn) { + defer downstream.Close() + requestID := newRequestID() + + start := time.Now() + + logEntry := map[string]string{ + "t": start.Format(time.RFC3339), + "request_id": requestID, + "remote_address": downstream.RemoteAddr().String(), + "upstream_address": "-", + } + + defer (func() { + l, _ := json.Marshal(logEntry) + v.logger.Info(string(l)) + })() + + addrs, err := v.resolve() + if err != nil { + v.logger.Errorf("[%s] error resolving backend %s:%d: %v", + requestID, v.Backend.Host, v.Backend.Port, err) + fmt.Fprintln(downstream, "error resolving backend") + return + } + + connCtx, cancel := context.WithTimeout( + v.ctx, + 5*time.Second) + defer cancel() + + upstream, err := rnet.DialHappyEyeballs(connCtx, addrs, uint16(v.Backend.Port)) + if err != nil { + v.logger.Errorf("[%s] error establishing connection to backend %s:%d: %v", + requestID, v.Backend.Host, v.Backend.Port, err) + fmt.Fprint(downstream, "error establishing connection to backend") + return + } + + logEntry["upstream_address"] = upstream.RemoteAddr().String() + + defer upstream.Close() + + wg := &sync.WaitGroup{} + wg.Add(2) + + var tx, rx int64 + txP := &tx + rxP := &rx + + go (func() { + defer wg.Done() + rx, _ := io.Copy(upstream, downstream) + *rxP = rx + })() + go (func() { + defer wg.Done() + tx, _ := io.Copy(downstream, upstream) + *txP = tx + })() + + wg.Wait() + + logEntry["bytes_received"] = fmt.Sprintf("%d", rx) + logEntry["bytes_sent"] = fmt.Sprintf("%d", tx) +} + +func (v *SNIVirtualHost) resolve() ([]net.Addr, error) { + if v.Backend.Host == "" { + return []net.Addr{ + &net.IPAddr{IP: net.IP{127, 0, 0, 1}}, + }, nil + } + + _, zone, _ := strings.Cut(v.Backend.Host, "%") + if ip := net.ParseIP(v.Backend.Host); ip != nil { + return []net.Addr{ + &net.IPAddr{IP: ip, Zone: zone}, + }, nil + } + + ip4, ip6, err := dns.ResolveDualStack(v.Backend.Host) + if err != nil { + return nil, err + } + + var out []net.Addr + if ip4 != "" { + out = append(out, &net.IPAddr{IP: net.ParseIP(ip4)}) + } + if ip6 != "" { + out = append(out, &net.IPAddr{IP: net.ParseIP(ip6)}) + } + + return out, nil +} diff --git a/sase/happy_eyeballs.go b/net/happy_eyeballs.go similarity index 92% rename from sase/happy_eyeballs.go rename to net/happy_eyeballs.go index 4e87c2e..626524a 100644 --- a/sase/happy_eyeballs.go +++ b/net/happy_eyeballs.go @@ -1,4 +1,4 @@ -package sase +package net import ( "context" @@ -13,7 +13,7 @@ import ( "go.fuhry.dev/runtime/utils/log" ) -var logger = log.WithPrefix("sase/happy_eyeballs") +var logger = log.WithPrefix("happy_eyeballs") var ( ipv4Enable, ipv6Enable bool @@ -24,7 +24,7 @@ func init() { flag.BoolVar(&ipv6Enable, "happy-eyeballs.ipv6", true, "allow happy eyeballs to dial IPv6 addresses") } -func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) { +func DialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) { logger.V(1).Debugf("dialHappyEyeballsNameContext: %s/%s", network, addr) if network != "tcp" { @@ -55,10 +55,10 @@ func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (ne addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv6)}) } - return dialHappyEyeballs(ctx, addrs, uint16(port)) + return DialHappyEyeballs(ctx, addrs, uint16(port)) } -func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) { +func DialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) { if !ipv4Enable && !ipv6Enable { return nil, fmt.Errorf("cannot dial happy eyeballs connection: at least one address family must be enabled") } diff --git a/sase/client.go b/sase/client.go index 6191bbd..6ac347f 100644 --- a/sase/client.go +++ b/sase/client.go @@ -8,6 +8,7 @@ import ( "github.com/gorilla/websocket" "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/net" ) type Client interface { @@ -28,7 +29,7 @@ func NewClient(id mtls.Identity) (Client, error) { } tlsConfig.RootCAs = nil wsClient := &websocket.Dialer{ - NetDialContext: dialHappyEyeballsNameContext, + NetDialContext: net.DialHappyEyeballsNameContext, TLSClientConfig: tlsConfig, }