]> go.fuhry.dev Git - runtime.git/commitdiff
[http] live reload, SNI proxying, bugfixes
authorDan Fuhry <dan@fuhry.com>
Thu, 6 Nov 2025 11:49:53 +0000 (06:49 -0500)
committerDan Fuhry <dan@fuhry.com>
Thu, 6 Nov 2025 11:49:53 +0000 (06:49 -0500)
- Initial backend live reload support, to be enabled in binary with merge of ephs/config_watcher
- Fix SNI listener closure upon single connection TLS handshake error

http/http.go
http/route_action_s3.go
http/server.go
http/sni_listener.go
http/tcp.go [new file with mode: 0644]
net/happy_eyeballs.go [moved from sase/happy_eyeballs.go with 92% similarity]
sase/client.go

index 1dfb1f34de21208617fc4f6ff86b628ef603c8ac..b65b01962c8bcf8b67a56d929d2d1899aeda30bb 100644 (file)
@@ -2,6 +2,7 @@ package http
 
 import (
        "context"
+       "crypto/tls"
        "errors"
        "fmt"
        "net"
@@ -18,7 +19,7 @@ type HTTPVirtualHost struct {
        *BaseVirtualHost `yaml:",inline"`
 
        Routes                 []*Route `yaml:"routes"`
-       Certificate            string   `yaml:"certificate"`
+       Certificate            string   `yaml:"cert"`
        TrustUpstreamRequestID *bool    `yaml:"trust_upstream_request_id"`
 
        tlsServer *http.Server
@@ -77,12 +78,13 @@ func (v *HTTPVirtualHost) NewHTTPServerWithContext(ctx context.Context) (*http.S
        return server, nil
 }
 
-func (v *HTTPVirtualHost) Shutdown(ctx context.Context) error {
+func (v *HTTPVirtualHost) Shutdown(ctx context.Context) (err error) {
        if v.tlsServer != nil {
-               return v.tlsServer.Shutdown(ctx)
+               err = v.tlsServer.Shutdown(ctx)
+               v.tlsServer = nil
        }
 
-       return nil
+       return
 }
 
 func (v *HTTPVirtualHost) handle(w http.ResponseWriter, r *http.Request) {
@@ -109,6 +111,10 @@ func (v *HTTPVirtualHost) handle(w http.ResponseWriter, r *http.Request) {
        v.fulfill(w, r, v.Routes)
 }
 
+func (v *HTTPVirtualHost) swapTlsConfig(c *tls.Config) {
+       v.tlsServer.TLSConfig = c
+}
+
 func vhostGlobToRegexp(name string) (*regexp.Regexp, error) {
        var parts []string
        for _, part := range strings.Split(name, ".") {
index 4e5b6eea0b720c0af9afeef1ef1b6034b04fac2b..40dc3a16c0dfc579dc70b3b9447c9a3d09203bb1 100644 (file)
@@ -71,7 +71,7 @@ func (a *S3Action) Handle(w http.ResponseWriter, r *http.Request, next http.Hand
        stat, err := object.Stat()
        if err != nil {
                er := minio.ToErrorResponse(err)
-               w.WriteHeader(er.StatusCode)
+               w.WriteHeader(coalesce(er.StatusCode, http.StatusInternalServerError))
                w.Write([]byte(fmt.Sprintf("failed to stat object %q: %+v", objPath, err)))
                return
        }
@@ -188,6 +188,16 @@ func s3ActionFromRouteYaml(node *yaml.Node) (RouteAction, error) {
        return nil, nil
 }
 
+func coalesce[T comparable](items ...T) T {
+       var empty T
+       for _, last := range items {
+               if last != empty {
+                       return last
+               }
+       }
+       return empty
+}
+
 func init() {
        AddRouteParseFunc(s3ActionFromRouteYaml)
 }
index 11941249a84ddd43876ef6a5fd76b2c5f95f47f9..18ae75f0b72148a51174ee5d66d9f421cceb2772 100644 (file)
@@ -15,6 +15,7 @@ import (
        "sync"
        "time"
 
+       "go.fuhry.dev/runtime/mtls"
        "go.fuhry.dev/runtime/utils/hashset"
        "go.fuhry.dev/runtime/utils/log"
        "go.fuhry.dev/runtime/utils/stringmatch"
@@ -34,6 +35,7 @@ type Route struct {
 
 type NameMatcher interface {
        MatchesName(string) bool
+       Names() []string
 }
 
 type VirtualHost interface {
@@ -56,21 +58,24 @@ type Listener struct {
        // ExternalPort is an optional port to use when sending http-to-https redirects. If external traffic
        // enters through a different port than what the server is bound to locally, this setting ensures that
        // redirects reference the correct port.
-       ExternalPort           int                         `yaml:"external_port"`
-       ProxyProtocol          bool                        `yaml:"proxy_protocol"`
-       InsecureAddr           string                      `yaml:"listen_insecure"`
-       Certificate            string                      `yaml:"cert"`
-       TrustUpstreamRequestID bool                        `yaml:"trust_upstream_request_id"`
-       VirtualHosts           []VirtualHost               `yaml:"-"`
-       HTTPVirtualHosts       map[string]*HTTPVirtualHost `yaml:"virtual_hosts"`
+       ExternalPort           int           `yaml:"external_port"`
+       ProxyProtocol          bool          `yaml:"proxy_protocol"`
+       InsecureAddr           string        `yaml:"listen_insecure"`
+       Certificate            string        `yaml:"cert"`
+       TrustUpstreamRequestID bool          `yaml:"trust_upstream_request_id"`
+       DrainTime              time.Duration `yaml:"drain_time"`
+       VirtualHosts           []VirtualHost `yaml:"-"`
 
        sniServer  *SNIListener
        httpServer *http.Server
+       wg         sync.WaitGroup
 }
 
 type Server struct {
        Listeners []*Listener     `yaml:"listeners"`
        Context   context.Context `yaml:"-"`
+
+       servingTLS, servingHTTP bool
 }
 
 type Authorization struct {
@@ -99,6 +104,10 @@ const (
        hostnameWildcardMulti  = `[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*`
 )
 
+const (
+       DefaultDrainTime = 30 * time.Second
+)
+
 var ErrServerClosed = http.ErrServerClosed
 var randSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
 var portSpec = regexp.MustCompile(":[0-9]{1,5}$")
@@ -160,6 +169,100 @@ func (r *Route) UnmarshalYAML(node *yaml.Node) error {
        return nil
 }
 
+func (l *Listener) UnmarshalYAML(node *yaml.Node) error {
+       var rawNode struct {
+               Addr                   string `yaml:"listen"`
+               ExternalPort           int    `yaml:"external_port"`
+               ProxyProtocol          bool   `yaml:"proxy_protocol"`
+               InsecureAddr           string `yaml:"listen_insecure"`
+               Certificate            string `yaml:"cert"`
+               TrustUpstreamRequestID bool   `yaml:"trust_upstream_request_id"`
+               DrainTime              string `yaml:"drain_time"`
+       }
+
+       var httpNode struct {
+               HTTPVirtualHosts map[string]*HTTPVirtualHost `yaml:"virtual_hosts"`
+       }
+
+       var tcpNode struct {
+               SNIVirtualHosts map[string]*SNIVirtualHost `yaml:"virtual_hosts"`
+       }
+
+       if err := node.Decode(&rawNode); err != nil {
+               return err
+       }
+
+       if err := node.Decode(&httpNode); err != nil {
+               return err
+       }
+
+       if err := node.Decode(&tcpNode); err != nil {
+               return err
+       }
+
+       l.Addr = rawNode.Addr
+       l.ExternalPort = rawNode.ExternalPort
+       l.ProxyProtocol = rawNode.ProxyProtocol
+       l.InsecureAddr = rawNode.InsecureAddr
+       l.Certificate = rawNode.Certificate
+       l.TrustUpstreamRequestID = rawNode.TrustUpstreamRequestID
+
+       if rawNode.DrainTime != "" {
+               if dur, err := time.ParseDuration(rawNode.DrainTime); err == nil {
+                       l.DrainTime = dur
+               } else {
+                       return fmt.Errorf("failed to parse drain_time: %v", err)
+               }
+       } else {
+               l.DrainTime = DefaultDrainTime
+       }
+
+       coveredVhosts := hashset.NewHashSet[string]()
+
+       if httpNode.HTTPVirtualHosts != nil {
+               for name, vhost := range httpNode.HTTPVirtualHosts {
+                       if vhost.Routes == nil {
+                               continue
+                       }
+                       if vhost.BaseVirtualHost == nil {
+                               vhost.BaseVirtualHost = &BaseVirtualHost{}
+                       }
+                       coveredVhosts.Add(name)
+                       vhost.Aliases = append(vhost.Aliases, name)
+                       l.VirtualHosts = append(l.VirtualHosts, vhost)
+               }
+       }
+       if tcpNode.SNIVirtualHosts != nil {
+               for name, vhost := range tcpNode.SNIVirtualHosts {
+                       if vhost.Backend == nil {
+                               continue
+                       }
+
+                       if vhost.BaseVirtualHost == nil {
+                               vhost.BaseVirtualHost = &BaseVirtualHost{}
+                       }
+
+                       if coveredVhosts.Contains(name) {
+                               return fmt.Errorf(
+                                       "virtual host %q contains both http routes and a SNI backend, this is not allowed",
+                                       name)
+                       }
+
+                       coveredVhosts.Add(name)
+                       vhost.Aliases = append(vhost.Aliases, name)
+                       l.VirtualHosts = append(l.VirtualHosts, vhost)
+               }
+       }
+
+       if coveredVhosts.Len() < len(httpNode.HTTPVirtualHosts) {
+               return fmt.Errorf(
+                       "%d virtual hosts are unconfigured, specify either routes for HTTP or tcp_backend "+
+                               "for SNI proxying", len(httpNode.HTTPVirtualHosts)-coveredVhosts.Len())
+       }
+
+       return nil
+}
+
 // UnmarshalYAML implements yaml.Unmarshaler
 func (s *Server) UnmarshalYAML(node *yaml.Node) error {
        var lc struct {
@@ -183,17 +286,7 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error {
                lc.Listeners = []*Listener{lc.Listener}
        }
 
-       s.Listeners = lc.Listeners
-
-       for _, listener := range s.Listeners {
-               for name, vhost := range listener.HTTPVirtualHosts {
-                       if vhost.BaseVirtualHost == nil {
-                               vhost.BaseVirtualHost = &BaseVirtualHost{}
-                       }
-                       vhost.Aliases = append(vhost.Aliases, name)
-                       listener.VirtualHosts = append(listener.VirtualHosts, vhost)
-               }
-
+       for _, listener := range lc.Listeners {
                for i, vhost := range listener.VirtualHosts {
                        switch vhost := vhost.(type) {
                        case *HTTPVirtualHost:
@@ -204,14 +297,163 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error {
                                        vhost.TrustUpstreamRequestID = &listener.TrustUpstreamRequestID
                                }
                                if err := vhost.precompileAliases(); err != nil {
-                                       return fmt.Errorf("bootstrapping vhost %d: %v", i, err)
+                                       return fmt.Errorf("bootstrapping http vhost %d: %v", i, err)
+                               }
+                               log.Default().Debugf("vhost %+v cert: %s", vhost.Names(), vhost.Certificate)
+                       case *SNIVirtualHost:
+                               if err := vhost.precompileAliases(); err != nil {
+                                       return fmt.Errorf("bootstrapping sni vhost %d: %v", i, err)
                                }
+                               log.Default().Debugf("sni vhost %+v bootatrapped", vhost.Names())
                        default:
                                return fmt.Errorf("unsupported vhost type: %T", vhost)
                        }
                }
        }
 
+       if len(s.Listeners) > 0 && (s.servingHTTP || s.servingTLS) {
+               LoggerFromContext(s.Context).Noticef("attempting hot reload on %d listeners", len(lc.Listeners))
+               if len(s.Listeners) != len(lc.Listeners) {
+                       return fmt.Errorf("listener count mismatch (%d != %d), refusing hot reload", len(s.Listeners), len(lc.Listeners))
+               }
+
+               for i := range s.Listeners {
+                       listenerCtx := context.WithValue(s.Context, kListener, s.Listeners[i])
+
+                       oldL := s.Listeners[i]
+                       newL := lc.Listeners[i]
+
+                       if oldL.Addr != newL.Addr {
+                               return fmt.Errorf("hot reload refused: listener %d: modifying address/port is not supported", i)
+                       }
+
+                       if oldL.InsecureAddr != newL.InsecureAddr {
+                               return fmt.Errorf("hot reload refused: listener %d: modifying insecure address/port is not supported", i)
+                       }
+
+                       if oldL.ProxyProtocol != newL.ProxyProtocol {
+                               return fmt.Errorf("hot reload refused: listener %d: toggling proxy protocol is not supported", i)
+                       }
+
+                       oldL.TrustUpstreamRequestID = newL.TrustUpstreamRequestID
+                       oldL.ExternalPort = newL.ExternalPort
+                       oldL.DrainTime = newL.DrainTime
+                       oldL.Certificate = newL.Certificate
+
+                       // virtual hosts that need to be started or closed
+                       var (
+                               // virtual hosts to close are tracked by index, so we can remove them from the
+                               // l.VirtualHosts slice.
+                               closeVhosts []int
+                               // keep track of which vhosts corresponded to those that were already serving, so we
+                               // don't try to start them again
+                               matchedVhosts = hashset.NewHashSet[VirtualHost]()
+                       )
+                       for vi, vh := range oldL.VirtualHosts {
+                               var match VirtualHost
+                               for _, nvh := range newL.VirtualHosts {
+                                       if virtualHostsEqual(vh, nvh) {
+                                               match = nvh
+                                               matchedVhosts.Add(nvh)
+                                               break
+                                       }
+                               }
+
+                               if match == nil {
+                                       // none of the vhosts in the new config match the current one, which means this
+                                       // vhost is going away
+                                       LoggerFromContext(s.Context).Debugf(
+                                               "listener %d: vhost(%T, %+v) is going away",
+                                               i, vh, vh.Names(),
+                                       )
+                                       closeVhosts = append(closeVhosts, vi)
+                                       continue
+                               }
+
+                               switch oldVh := vh.(type) {
+                               case *HTTPVirtualHost:
+                                       newVh := match.(*HTTPVirtualHost)
+
+                                       oldVh.Routes = newVh.Routes
+                                       oldVh.TrustUpstreamRequestID = newVh.TrustUpstreamRequestID
+                                       if oldVh.Certificate != newVh.Certificate {
+                                               newCert := mtls.NewSSLCertificate(newVh.Certificate)
+                                               if tlsConfig, err := newCert.TlsConfig(listenerCtx); err == nil {
+                                                       oldVh.swapTlsConfig(tlsConfig)
+                                               } else {
+                                                       LoggerFromContext(listenerCtx).Errorf(
+                                                               "vhost %+v: error loading new ssl certificate %q: %v",
+                                                               oldVh.Names(), newVh.Certificate, err)
+                                               }
+                                       }
+                                       LoggerFromContext(listenerCtx).Debugf(
+                                               "loaded %d new routes for HTTP vhost %+v",
+                                               len(oldVh.Routes), oldVh.Names())
+                               case *SNIVirtualHost:
+                                       newVh := match.(*SNIVirtualHost)
+                                       oldVh.Backend = newVh.Backend
+                               }
+                       }
+
+                       // start newly added virtual hosts.
+                       // do this before old vhosts are shut down, because if there are no overlaps we need to
+                       // increment the listener's wait group before it's decremented by the old vhosts
+                       // shutting down.
+                       for _, vhost := range newL.VirtualHosts {
+                               if matchedVhosts.Contains(vhost) {
+                                       // already serving
+                                       continue
+                               }
+
+                               // this is a new vhost, start serving.
+                               serveFunc, err := vhost.Serve(listenerCtx, oldL.sniServer)
+                               if err != nil {
+                                       LoggerFromContext(listenerCtx).Warningf(
+                                               "failed to start vhost(%T, %+v): %v",
+                                               vhost, vhost.Names(), err)
+                                       continue
+                               }
+
+                               oldL.wg.Add(1)
+                               go (func() {
+                                       defer oldL.wg.Done()
+                                       serveFunc()
+                               })()
+
+                               oldL.VirtualHosts = append(oldL.VirtualHosts, vhost)
+
+                               LoggerFromContext(listenerCtx).Debugf("started new vhost %+v", vhost.Names())
+                       }
+
+                       if len(closeVhosts) > 0 {
+                               for vi := len(closeVhosts) - 1; vi >= 0; vi -= 1 {
+                                       idx := closeVhosts[vi]
+                                       shutdownCtx, cancel := context.WithTimeout(listenerCtx, oldL.DrainTime)
+                                       defer cancel()
+
+                                       vhost := oldL.VirtualHosts[idx]
+
+                                       if err := vhost.Shutdown(shutdownCtx); err != nil {
+                                               LoggerFromContext(listenerCtx).Warningf(
+                                                       "failed to shutdown vhost(%T, %+v): %v",
+                                                       vhost, vhost.Names(), err)
+                                       }
+
+                                       LoggerFromContext(listenerCtx).Debugf("shut down removed vhost %+v", vhost.Names())
+                                       oldL.VirtualHosts = slices.Concat(
+                                               oldL.VirtualHosts[:idx],
+                                               oldL.VirtualHosts[idx+1:],
+                                       )
+                               }
+                       }
+               }
+
+               LoggerFromContext(s.Context).Noticef("hot reload complete")
+               return nil
+       }
+
+       s.Listeners = lc.Listeners
+
        for _, initHook := range initHooks {
                newCtx, err := initHook(s.Context, node)
                if err != nil {
@@ -224,6 +466,8 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error {
 }
 
 func (s *Server) ListenAndServeTLS() (err error) {
+       s.servingTLS = true
+       defer (func() { s.servingTLS = false })()
        wg := &sync.WaitGroup{}
 
        for _, listener := range s.Listeners {
@@ -241,6 +485,9 @@ func (s *Server) ListenAndServeTLS() (err error) {
 }
 
 func (s *Server) ListenAndServe() (err error) {
+       s.servingHTTP = true
+       defer (func() { s.servingHTTP = false })()
+
        wg := &sync.WaitGroup{}
        for _, listener := range s.Listeners {
                var netListener net.Listener
@@ -343,23 +590,22 @@ func (l *Listener) ListenAndServeSNI(ctx context.Context) (err error) {
 
        l.sniServer = NewSNIListener(netListener)
 
-       wg := &sync.WaitGroup{}
        for _, vhost := range l.VirtualHosts {
                serveFunc, err := vhost.Serve(ctx, l.sniServer)
                if err != nil {
                        return err
                }
 
-               wg.Add(1)
+               l.wg.Add(1)
                go (func() {
-                       defer wg.Done()
+                       defer l.wg.Done()
                        serveFunc()
                })()
        }
 
-       wg.Add(1)
+       l.wg.Add(1)
        go (func() {
-               defer wg.Done()
+               defer l.wg.Done()
                err = l.sniServer.Serve()
        })()
 
@@ -465,6 +711,32 @@ func (v *BaseVirtualHost) MatchesName(hostname string) bool {
        return false
 }
 
+func (v *BaseVirtualHost) Names() (out []string) {
+       for _, name := range v.aliasesExact.AsSlice() {
+               out = append(out, name)
+       }
+       for _, re := range v.aliasesCompiled {
+               out = append(out, re.String())
+       }
+
+       return
+}
+
+func virtualHostsEqual(a, b VirtualHost) bool {
+       if fmt.Sprintf("%T", a) != fmt.Sprintf("%T", b) {
+               return false
+       }
+
+       aNames := hashset.FromSlice(a.Names())
+       bNames := hashset.FromSlice(b.Names())
+
+       if !aNames.Equal(bNames) {
+               return false
+       }
+
+       return true
+}
+
 func LoggerFromContext(ctx context.Context) log.Logger {
        l := ctx.Value(kLogger)
        if logger, ok := l.(log.Logger); ok {
index a7d501b6c94d8583c120d9a861fcd80d6fdd384a..df33671b9f37247195da867388d16b5efd039de2 100644 (file)
@@ -6,6 +6,7 @@ import (
        "fmt"
        "io"
        "net"
+       "sync"
        "time"
 
        "go.fuhry.dev/runtime/utils/log"
@@ -17,11 +18,27 @@ type sniConn struct {
 }
 
 type sniVirtualHost struct {
-       tlsConfig   *tls.Config
-       sniListener *SNIListener
-       matcher     NameMatcher
-       connChan    chan sniConn
-       logger      log.Logger
+       tlsConfig    *tls.Config
+       sniListener  *SNIListener
+       matcher      NameMatcher
+       connChan     chan sniConn
+       logger       log.Logger
+       shutdownOnce sync.Once
+       accepting    bool
+}
+
+type temporaryError struct {
+       error
+}
+
+var _ net.Error = &temporaryError{}
+
+func (e *temporaryError) Temporary() bool {
+       return true
+}
+
+func (e *temporaryError) Timeout() bool {
+       return false
 }
 
 // SNIListener is a multiplexing listener that routes connections based on the
@@ -31,7 +48,8 @@ type sniVirtualHost struct {
 // created with net.ListenTCP).
 type SNIListener struct {
        listener net.Listener
-       vhosts   []*sniVirtualHost
+       vhosts   map[*sniVirtualHost]struct{}
+       vhostsMu sync.Mutex
        logger   log.Logger
 }
 
@@ -40,6 +58,7 @@ func NewSNIListener(l net.Listener) *SNIListener {
        listener := &SNIListener{
                listener: l,
                logger:   logger,
+               vhosts:   make(map[*sniVirtualHost]struct{}),
        }
 
        return listener
@@ -49,14 +68,39 @@ func (l *SNIListener) Serve() error {
        for {
                conn, err := l.listener.Accept()
                if err != nil {
-                       for _, v := range l.vhosts {
-                               v.connChan <- sniConn{nil, err}
+                       if netErr, ok := err.(net.Error); ok {
+                               if netErr.Temporary() {
+                                       l.logger.Noticef("temporary error calling accept: %v", netErr)
+                                       continue
+                               }
+                               if netErr.Timeout() {
+                                       l.logger.Noticef("timeout error calling accept: %v", netErr)
+                                       continue
+                               }
+
+                               l.logger.Errorf(
+                                       "permanent error calling Accept(), propagating to vhosts: %v",
+                                       netErr)
+                       } else {
+                               l.logger.Errorf(
+                                       "non-net.Error of type %T calling Accept(), propagating to vhosts: %v",
+                                       err, err)
+                       }
+                       l.vhostsMu.Lock()
+                       defer l.vhostsMu.Unlock()
+
+                       for v := range l.vhosts {
+                               v.msg(nil, err)
                        }
                        return err
                }
 
-               l.logger.V(2).Infof("new conn: %s <-> %s", conn.RemoteAddr(), conn.LocalAddr())
-               go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("[%s]", conn.RemoteAddr())))
+               l.logger.V(1).Infof("new conn: %s <-> %s, %d vhosts to check",
+                       conn.RemoteAddr(),
+                       conn.LocalAddr(),
+                       len(l.vhosts))
+
+               go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("<->%s", conn.RemoteAddr())))
        }
 }
 
@@ -87,14 +131,20 @@ func (l *SNIListener) handle(conn net.Conn, logger log.Logger) {
                writer: conn,
        }
 
-       for _, vhost := range l.vhosts {
+       l.vhostsMu.Lock()
+       defer l.vhostsMu.Unlock()
+       for vhost := range l.vhosts {
                if vhost.matcher.MatchesName(clientHello.ServerName) {
-                       vhost.connChan <- sniConn{childConn, nil}
+                       logger.V(1).Infof(
+                               "new conn matched vhost %+v for name %q, dispatch",
+                               vhost.matcher.Names(),
+                               clientHello.ServerName)
+                       vhost.msg(childConn, nil)
                        return
                }
        }
 
-       logger.Warningf("no servername match: %q", clientHello.ServerName)
+       l.logger.Warningf("no servername match: %q", clientHello.ServerName)
        conn.Close()
 }
 
@@ -107,7 +157,9 @@ func (l *SNIListener) AddVirtualHost(matcher NameMatcher, tlsConfig *tls.Config)
                logger:      l.logger,
        }
 
-       l.vhosts = append(l.vhosts, vhost)
+       l.vhostsMu.Lock()
+       defer l.vhostsMu.Unlock()
+       l.vhosts[vhost] = struct{}{}
 
        return vhost, nil
 }
@@ -126,8 +178,22 @@ func (l *SNIListener) Close() error {
        return l.listener.Close()
 }
 
+func (l *SNIListener) deregister(v *sniVirtualHost) {
+       v.logger.V(1).Debugf("deregistering SNI vhost %+v from listener %v", v.matcher.Names(), l.Addr())
+
+       l.vhostsMu.Lock()
+       defer l.vhostsMu.Unlock()
+
+       delete(l.vhosts, v)
+}
+
 func (v *sniVirtualHost) Accept() (net.Conn, error) {
        c := <-v.connChan
+
+       if c.conn == nil && c.err == nil {
+               return nil, net.ErrClosed
+       }
+
        if c.conn != nil {
                v.logger.V(3).Noticef("recv'd dispatched conn: %s <-> %s", c.conn.RemoteAddr(), c.conn.LocalAddr())
        } else if c.err != nil {
@@ -139,14 +205,7 @@ func (v *sniVirtualHost) Accept() (net.Conn, error) {
                return c.conn, nil
        }
 
-       tlsConn := tls.Server(c.conn, v.tlsConfig)
-
-       if err := tlsConn.Handshake(); err != nil {
-               c.conn.Close()
-               return nil, err
-       }
-
-       return tlsConn, nil
+       return tls.Server(c.conn, v.tlsConfig), nil
 }
 
 func (v *sniVirtualHost) Addr() net.Addr {
@@ -154,7 +213,28 @@ func (v *sniVirtualHost) Addr() net.Addr {
 }
 
 func (v *sniVirtualHost) Close() error {
-       return v.sniListener.Close()
+       v.shutdownOnce.Do(func() {
+               v.sniListener.deregister(v)
+
+               if v.connChan != nil {
+                       v.msg(nil, net.ErrClosed)
+                       close(v.connChan)
+                       v.connChan = nil
+               }
+       })
+
+       return nil
+}
+
+func (v *sniVirtualHost) msg(conn net.Conn, err error) {
+       if v.connChan != nil {
+               select {
+               case v.connChan <- sniConn{conn, err}:
+                       // yay message sent
+               default:
+                       // oops
+               }
+       }
 }
 
 // ref: https://www.agwa.name/blog/post/writing_an_sni_proxy_in_go
diff --git a/http/tcp.go b/http/tcp.go
new file mode 100644 (file)
index 0000000..2a55de5
--- /dev/null
@@ -0,0 +1,169 @@
+package http
+
+import (
+       "context"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "io"
+       "net"
+       "strings"
+       "sync"
+       "time"
+
+       rnet "go.fuhry.dev/runtime/net"
+       "go.fuhry.dev/runtime/net/dns"
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+type SNIVirtualHost struct {
+       *BaseVirtualHost `yaml:",inline"`
+
+       Backend *struct {
+               Host string `yaml:"host"`
+               Port int    `yaml:"port"`
+       } `yaml:"tcp_backend"`
+
+       ctx      context.Context
+       logger   log.Logger
+       listener net.Listener
+}
+
+func (v *SNIVirtualHost) Serve(ctx context.Context, sniServer *SNIListener) (func() error, error) {
+       if v.Backend.Port < 1 || v.Backend.Port > 65535 {
+               return nil, errors.New("backend port must be 1-65535")
+       }
+       if len(v.Aliases) < 1 {
+               return nil, errors.New("no aliases specified for TCP vhost")
+       }
+
+       v.logger = LoggerFromContext(ctx).AppendPrefix(fmt.Sprintf(".SNI<%s>", v.Aliases[0]))
+
+       sniListener, err := sniServer.AddVirtualHost(v, nil)
+       if err != nil {
+               return nil, err
+       }
+
+       v.ctx = ctx
+       v.listener = sniListener
+
+       return v.serve, nil
+}
+
+func (v *SNIVirtualHost) Shutdown(ctx context.Context) (err error) {
+       if v.listener != nil {
+               err = v.listener.Close()
+               v.listener = nil
+       }
+       return
+}
+
+func (v *SNIVirtualHost) serve() error {
+       defer v.listener.Close()
+
+       for {
+               conn, err := v.listener.Accept()
+               if err != nil {
+                       return err
+               }
+
+               go v.conn(conn)
+       }
+}
+
+func (v *SNIVirtualHost) conn(downstream net.Conn) {
+       defer downstream.Close()
+       requestID := newRequestID()
+
+       start := time.Now()
+
+       logEntry := map[string]string{
+               "t":                start.Format(time.RFC3339),
+               "request_id":       requestID,
+               "remote_address":   downstream.RemoteAddr().String(),
+               "upstream_address": "-",
+       }
+
+       defer (func() {
+               l, _ := json.Marshal(logEntry)
+               v.logger.Info(string(l))
+       })()
+
+       addrs, err := v.resolve()
+       if err != nil {
+               v.logger.Errorf("[%s] error resolving backend %s:%d: %v",
+                       requestID, v.Backend.Host, v.Backend.Port, err)
+               fmt.Fprintln(downstream, "error resolving backend")
+               return
+       }
+
+       connCtx, cancel := context.WithTimeout(
+               v.ctx,
+               5*time.Second)
+       defer cancel()
+
+       upstream, err := rnet.DialHappyEyeballs(connCtx, addrs, uint16(v.Backend.Port))
+       if err != nil {
+               v.logger.Errorf("[%s] error establishing connection to backend %s:%d: %v",
+                       requestID, v.Backend.Host, v.Backend.Port, err)
+               fmt.Fprint(downstream, "error establishing connection to backend")
+               return
+       }
+
+       logEntry["upstream_address"] = upstream.RemoteAddr().String()
+
+       defer upstream.Close()
+
+       wg := &sync.WaitGroup{}
+       wg.Add(2)
+
+       var tx, rx int64
+       txP := &tx
+       rxP := &rx
+
+       go (func() {
+               defer wg.Done()
+               rx, _ := io.Copy(upstream, downstream)
+               *rxP = rx
+       })()
+       go (func() {
+               defer wg.Done()
+               tx, _ := io.Copy(downstream, upstream)
+               *txP = tx
+       })()
+
+       wg.Wait()
+
+       logEntry["bytes_received"] = fmt.Sprintf("%d", rx)
+       logEntry["bytes_sent"] = fmt.Sprintf("%d", tx)
+}
+
+func (v *SNIVirtualHost) resolve() ([]net.Addr, error) {
+       if v.Backend.Host == "" {
+               return []net.Addr{
+                       &net.IPAddr{IP: net.IP{127, 0, 0, 1}},
+               }, nil
+       }
+
+       _, zone, _ := strings.Cut(v.Backend.Host, "%")
+       if ip := net.ParseIP(v.Backend.Host); ip != nil {
+               return []net.Addr{
+                       &net.IPAddr{IP: ip, Zone: zone},
+               }, nil
+       }
+
+       ip4, ip6, err := dns.ResolveDualStack(v.Backend.Host)
+       if err != nil {
+               return nil, err
+       }
+
+       var out []net.Addr
+       if ip4 != "" {
+               out = append(out, &net.IPAddr{IP: net.ParseIP(ip4)})
+       }
+       if ip6 != "" {
+               out = append(out, &net.IPAddr{IP: net.ParseIP(ip6)})
+       }
+
+       return out, nil
+}
similarity index 92%
rename from sase/happy_eyeballs.go
rename to net/happy_eyeballs.go
index 4e87c2e7c614c357e49e408ea73df9158fbac16c..626524ac3e26d4b351e3e9dd5611aca22a996a3f 100644 (file)
@@ -1,4 +1,4 @@
-package sase
+package net
 
 import (
        "context"
@@ -13,7 +13,7 @@ import (
        "go.fuhry.dev/runtime/utils/log"
 )
 
-var logger = log.WithPrefix("sase/happy_eyeballs")
+var logger = log.WithPrefix("happy_eyeballs")
 var (
        ipv4Enable,
        ipv6Enable bool
@@ -24,7 +24,7 @@ func init() {
        flag.BoolVar(&ipv6Enable, "happy-eyeballs.ipv6", true, "allow happy eyeballs to dial IPv6 addresses")
 }
 
-func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) {
+func DialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) {
        logger.V(1).Debugf("dialHappyEyeballsNameContext: %s/%s", network, addr)
 
        if network != "tcp" {
@@ -55,10 +55,10 @@ func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (ne
                addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv6)})
        }
 
-       return dialHappyEyeballs(ctx, addrs, uint16(port))
+       return DialHappyEyeballs(ctx, addrs, uint16(port))
 }
 
-func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) {
+func DialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) {
        if !ipv4Enable && !ipv6Enable {
                return nil, fmt.Errorf("cannot dial happy eyeballs connection: at least one address family must be enabled")
        }
index 6191bbdb5837f633754c43b58424301dbfc6a81e..6ac347f951deeb601eb2f9f3da0c77f835bef5ff 100644 (file)
@@ -8,6 +8,7 @@ import (
 
        "github.com/gorilla/websocket"
        "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/net"
 )
 
 type Client interface {
@@ -28,7 +29,7 @@ func NewClient(id mtls.Identity) (Client, error) {
        }
        tlsConfig.RootCAs = nil
        wsClient := &websocket.Dialer{
-               NetDialContext:  dialHappyEyeballsNameContext,
+               NetDialContext:  net.DialHappyEyeballsNameContext,
                TLSClientConfig: tlsConfig,
        }