"sync"
"time"
+ "go.fuhry.dev/runtime/mtls"
"go.fuhry.dev/runtime/utils/hashset"
"go.fuhry.dev/runtime/utils/log"
"go.fuhry.dev/runtime/utils/stringmatch"
type NameMatcher interface {
MatchesName(string) bool
+ Names() []string
}
type VirtualHost interface {
// ExternalPort is an optional port to use when sending http-to-https redirects. If external traffic
// enters through a different port than what the server is bound to locally, this setting ensures that
// redirects reference the correct port.
- ExternalPort int `yaml:"external_port"`
- ProxyProtocol bool `yaml:"proxy_protocol"`
- InsecureAddr string `yaml:"listen_insecure"`
- Certificate string `yaml:"cert"`
- TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"`
- VirtualHosts []VirtualHost `yaml:"-"`
- HTTPVirtualHosts map[string]*HTTPVirtualHost `yaml:"virtual_hosts"`
+ ExternalPort int `yaml:"external_port"`
+ ProxyProtocol bool `yaml:"proxy_protocol"`
+ InsecureAddr string `yaml:"listen_insecure"`
+ Certificate string `yaml:"cert"`
+ TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"`
+ DrainTime time.Duration `yaml:"drain_time"`
+ VirtualHosts []VirtualHost `yaml:"-"`
sniServer *SNIListener
httpServer *http.Server
+ wg sync.WaitGroup
}
type Server struct {
Listeners []*Listener `yaml:"listeners"`
Context context.Context `yaml:"-"`
+
+ servingTLS, servingHTTP bool
}
type Authorization struct {
hostnameWildcardMulti = `[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*`
)
+const (
+ DefaultDrainTime = 30 * time.Second
+)
+
var ErrServerClosed = http.ErrServerClosed
var randSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
var portSpec = regexp.MustCompile(":[0-9]{1,5}$")
return nil
}
+func (l *Listener) UnmarshalYAML(node *yaml.Node) error {
+ var rawNode struct {
+ Addr string `yaml:"listen"`
+ ExternalPort int `yaml:"external_port"`
+ ProxyProtocol bool `yaml:"proxy_protocol"`
+ InsecureAddr string `yaml:"listen_insecure"`
+ Certificate string `yaml:"cert"`
+ TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"`
+ DrainTime string `yaml:"drain_time"`
+ }
+
+ var httpNode struct {
+ HTTPVirtualHosts map[string]*HTTPVirtualHost `yaml:"virtual_hosts"`
+ }
+
+ var tcpNode struct {
+ SNIVirtualHosts map[string]*SNIVirtualHost `yaml:"virtual_hosts"`
+ }
+
+ if err := node.Decode(&rawNode); err != nil {
+ return err
+ }
+
+ if err := node.Decode(&httpNode); err != nil {
+ return err
+ }
+
+ if err := node.Decode(&tcpNode); err != nil {
+ return err
+ }
+
+ l.Addr = rawNode.Addr
+ l.ExternalPort = rawNode.ExternalPort
+ l.ProxyProtocol = rawNode.ProxyProtocol
+ l.InsecureAddr = rawNode.InsecureAddr
+ l.Certificate = rawNode.Certificate
+ l.TrustUpstreamRequestID = rawNode.TrustUpstreamRequestID
+
+ if rawNode.DrainTime != "" {
+ if dur, err := time.ParseDuration(rawNode.DrainTime); err == nil {
+ l.DrainTime = dur
+ } else {
+ return fmt.Errorf("failed to parse drain_time: %v", err)
+ }
+ } else {
+ l.DrainTime = DefaultDrainTime
+ }
+
+ coveredVhosts := hashset.NewHashSet[string]()
+
+ if httpNode.HTTPVirtualHosts != nil {
+ for name, vhost := range httpNode.HTTPVirtualHosts {
+ if vhost.Routes == nil {
+ continue
+ }
+ if vhost.BaseVirtualHost == nil {
+ vhost.BaseVirtualHost = &BaseVirtualHost{}
+ }
+ coveredVhosts.Add(name)
+ vhost.Aliases = append(vhost.Aliases, name)
+ l.VirtualHosts = append(l.VirtualHosts, vhost)
+ }
+ }
+ if tcpNode.SNIVirtualHosts != nil {
+ for name, vhost := range tcpNode.SNIVirtualHosts {
+ if vhost.Backend == nil {
+ continue
+ }
+
+ if vhost.BaseVirtualHost == nil {
+ vhost.BaseVirtualHost = &BaseVirtualHost{}
+ }
+
+ if coveredVhosts.Contains(name) {
+ return fmt.Errorf(
+ "virtual host %q contains both http routes and a SNI backend, this is not allowed",
+ name)
+ }
+
+ coveredVhosts.Add(name)
+ vhost.Aliases = append(vhost.Aliases, name)
+ l.VirtualHosts = append(l.VirtualHosts, vhost)
+ }
+ }
+
+ if coveredVhosts.Len() < len(httpNode.HTTPVirtualHosts) {
+ return fmt.Errorf(
+ "%d virtual hosts are unconfigured, specify either routes for HTTP or tcp_backend "+
+ "for SNI proxying", len(httpNode.HTTPVirtualHosts)-coveredVhosts.Len())
+ }
+
+ return nil
+}
+
// UnmarshalYAML implements yaml.Unmarshaler
func (s *Server) UnmarshalYAML(node *yaml.Node) error {
var lc struct {
lc.Listeners = []*Listener{lc.Listener}
}
- s.Listeners = lc.Listeners
-
- for _, listener := range s.Listeners {
- for name, vhost := range listener.HTTPVirtualHosts {
- if vhost.BaseVirtualHost == nil {
- vhost.BaseVirtualHost = &BaseVirtualHost{}
- }
- vhost.Aliases = append(vhost.Aliases, name)
- listener.VirtualHosts = append(listener.VirtualHosts, vhost)
- }
-
+ for _, listener := range lc.Listeners {
for i, vhost := range listener.VirtualHosts {
switch vhost := vhost.(type) {
case *HTTPVirtualHost:
vhost.TrustUpstreamRequestID = &listener.TrustUpstreamRequestID
}
if err := vhost.precompileAliases(); err != nil {
- return fmt.Errorf("bootstrapping vhost %d: %v", i, err)
+ return fmt.Errorf("bootstrapping http vhost %d: %v", i, err)
+ }
+ log.Default().Debugf("vhost %+v cert: %s", vhost.Names(), vhost.Certificate)
+ case *SNIVirtualHost:
+ if err := vhost.precompileAliases(); err != nil {
+ return fmt.Errorf("bootstrapping sni vhost %d: %v", i, err)
}
+ log.Default().Debugf("sni vhost %+v bootatrapped", vhost.Names())
default:
return fmt.Errorf("unsupported vhost type: %T", vhost)
}
}
}
+ if len(s.Listeners) > 0 && (s.servingHTTP || s.servingTLS) {
+ LoggerFromContext(s.Context).Noticef("attempting hot reload on %d listeners", len(lc.Listeners))
+ if len(s.Listeners) != len(lc.Listeners) {
+ return fmt.Errorf("listener count mismatch (%d != %d), refusing hot reload", len(s.Listeners), len(lc.Listeners))
+ }
+
+ for i := range s.Listeners {
+ listenerCtx := context.WithValue(s.Context, kListener, s.Listeners[i])
+
+ oldL := s.Listeners[i]
+ newL := lc.Listeners[i]
+
+ if oldL.Addr != newL.Addr {
+ return fmt.Errorf("hot reload refused: listener %d: modifying address/port is not supported", i)
+ }
+
+ if oldL.InsecureAddr != newL.InsecureAddr {
+ return fmt.Errorf("hot reload refused: listener %d: modifying insecure address/port is not supported", i)
+ }
+
+ if oldL.ProxyProtocol != newL.ProxyProtocol {
+ return fmt.Errorf("hot reload refused: listener %d: toggling proxy protocol is not supported", i)
+ }
+
+ oldL.TrustUpstreamRequestID = newL.TrustUpstreamRequestID
+ oldL.ExternalPort = newL.ExternalPort
+ oldL.DrainTime = newL.DrainTime
+ oldL.Certificate = newL.Certificate
+
+ // virtual hosts that need to be started or closed
+ var (
+ // virtual hosts to close are tracked by index, so we can remove them from the
+ // l.VirtualHosts slice.
+ closeVhosts []int
+ // keep track of which vhosts corresponded to those that were already serving, so we
+ // don't try to start them again
+ matchedVhosts = hashset.NewHashSet[VirtualHost]()
+ )
+ for vi, vh := range oldL.VirtualHosts {
+ var match VirtualHost
+ for _, nvh := range newL.VirtualHosts {
+ if virtualHostsEqual(vh, nvh) {
+ match = nvh
+ matchedVhosts.Add(nvh)
+ break
+ }
+ }
+
+ if match == nil {
+ // none of the vhosts in the new config match the current one, which means this
+ // vhost is going away
+ LoggerFromContext(s.Context).Debugf(
+ "listener %d: vhost(%T, %+v) is going away",
+ i, vh, vh.Names(),
+ )
+ closeVhosts = append(closeVhosts, vi)
+ continue
+ }
+
+ switch oldVh := vh.(type) {
+ case *HTTPVirtualHost:
+ newVh := match.(*HTTPVirtualHost)
+
+ oldVh.Routes = newVh.Routes
+ oldVh.TrustUpstreamRequestID = newVh.TrustUpstreamRequestID
+ if oldVh.Certificate != newVh.Certificate {
+ newCert := mtls.NewSSLCertificate(newVh.Certificate)
+ if tlsConfig, err := newCert.TlsConfig(listenerCtx); err == nil {
+ oldVh.swapTlsConfig(tlsConfig)
+ } else {
+ LoggerFromContext(listenerCtx).Errorf(
+ "vhost %+v: error loading new ssl certificate %q: %v",
+ oldVh.Names(), newVh.Certificate, err)
+ }
+ }
+ LoggerFromContext(listenerCtx).Debugf(
+ "loaded %d new routes for HTTP vhost %+v",
+ len(oldVh.Routes), oldVh.Names())
+ case *SNIVirtualHost:
+ newVh := match.(*SNIVirtualHost)
+ oldVh.Backend = newVh.Backend
+ }
+ }
+
+ // start newly added virtual hosts.
+ // do this before old vhosts are shut down, because if there are no overlaps we need to
+ // increment the listener's wait group before it's decremented by the old vhosts
+ // shutting down.
+ for _, vhost := range newL.VirtualHosts {
+ if matchedVhosts.Contains(vhost) {
+ // already serving
+ continue
+ }
+
+ // this is a new vhost, start serving.
+ serveFunc, err := vhost.Serve(listenerCtx, oldL.sniServer)
+ if err != nil {
+ LoggerFromContext(listenerCtx).Warningf(
+ "failed to start vhost(%T, %+v): %v",
+ vhost, vhost.Names(), err)
+ continue
+ }
+
+ oldL.wg.Add(1)
+ go (func() {
+ defer oldL.wg.Done()
+ serveFunc()
+ })()
+
+ oldL.VirtualHosts = append(oldL.VirtualHosts, vhost)
+
+ LoggerFromContext(listenerCtx).Debugf("started new vhost %+v", vhost.Names())
+ }
+
+ if len(closeVhosts) > 0 {
+ for vi := len(closeVhosts) - 1; vi >= 0; vi -= 1 {
+ idx := closeVhosts[vi]
+ shutdownCtx, cancel := context.WithTimeout(listenerCtx, oldL.DrainTime)
+ defer cancel()
+
+ vhost := oldL.VirtualHosts[idx]
+
+ if err := vhost.Shutdown(shutdownCtx); err != nil {
+ LoggerFromContext(listenerCtx).Warningf(
+ "failed to shutdown vhost(%T, %+v): %v",
+ vhost, vhost.Names(), err)
+ }
+
+ LoggerFromContext(listenerCtx).Debugf("shut down removed vhost %+v", vhost.Names())
+ oldL.VirtualHosts = slices.Concat(
+ oldL.VirtualHosts[:idx],
+ oldL.VirtualHosts[idx+1:],
+ )
+ }
+ }
+ }
+
+ LoggerFromContext(s.Context).Noticef("hot reload complete")
+ return nil
+ }
+
+ s.Listeners = lc.Listeners
+
for _, initHook := range initHooks {
newCtx, err := initHook(s.Context, node)
if err != nil {
}
func (s *Server) ListenAndServeTLS() (err error) {
+ s.servingTLS = true
+ defer (func() { s.servingTLS = false })()
wg := &sync.WaitGroup{}
for _, listener := range s.Listeners {
}
func (s *Server) ListenAndServe() (err error) {
+ s.servingHTTP = true
+ defer (func() { s.servingHTTP = false })()
+
wg := &sync.WaitGroup{}
for _, listener := range s.Listeners {
var netListener net.Listener
l.sniServer = NewSNIListener(netListener)
- wg := &sync.WaitGroup{}
for _, vhost := range l.VirtualHosts {
serveFunc, err := vhost.Serve(ctx, l.sniServer)
if err != nil {
return err
}
- wg.Add(1)
+ l.wg.Add(1)
go (func() {
- defer wg.Done()
+ defer l.wg.Done()
serveFunc()
})()
}
- wg.Add(1)
+ l.wg.Add(1)
go (func() {
- defer wg.Done()
+ defer l.wg.Done()
err = l.sniServer.Serve()
})()
return false
}
+func (v *BaseVirtualHost) Names() (out []string) {
+ for _, name := range v.aliasesExact.AsSlice() {
+ out = append(out, name)
+ }
+ for _, re := range v.aliasesCompiled {
+ out = append(out, re.String())
+ }
+
+ return
+}
+
+func virtualHostsEqual(a, b VirtualHost) bool {
+ if fmt.Sprintf("%T", a) != fmt.Sprintf("%T", b) {
+ return false
+ }
+
+ aNames := hashset.FromSlice(a.Names())
+ bNames := hashset.FromSlice(b.Names())
+
+ if !aNames.Equal(bNames) {
+ return false
+ }
+
+ return true
+}
+
func LoggerFromContext(ctx context.Context) log.Logger {
l := ctx.Value(kLogger)
if logger, ok := l.(log.Logger); ok {
"fmt"
"io"
"net"
+ "sync"
"time"
"go.fuhry.dev/runtime/utils/log"
}
type sniVirtualHost struct {
- tlsConfig *tls.Config
- sniListener *SNIListener
- matcher NameMatcher
- connChan chan sniConn
- logger log.Logger
+ tlsConfig *tls.Config
+ sniListener *SNIListener
+ matcher NameMatcher
+ connChan chan sniConn
+ logger log.Logger
+ shutdownOnce sync.Once
+ accepting bool
+}
+
+type temporaryError struct {
+ error
+}
+
+var _ net.Error = &temporaryError{}
+
+func (e *temporaryError) Temporary() bool {
+ return true
+}
+
+func (e *temporaryError) Timeout() bool {
+ return false
}
// SNIListener is a multiplexing listener that routes connections based on the
// created with net.ListenTCP).
type SNIListener struct {
listener net.Listener
- vhosts []*sniVirtualHost
+ vhosts map[*sniVirtualHost]struct{}
+ vhostsMu sync.Mutex
logger log.Logger
}
listener := &SNIListener{
listener: l,
logger: logger,
+ vhosts: make(map[*sniVirtualHost]struct{}),
}
return listener
for {
conn, err := l.listener.Accept()
if err != nil {
- for _, v := range l.vhosts {
- v.connChan <- sniConn{nil, err}
+ if netErr, ok := err.(net.Error); ok {
+ if netErr.Temporary() {
+ l.logger.Noticef("temporary error calling accept: %v", netErr)
+ continue
+ }
+ if netErr.Timeout() {
+ l.logger.Noticef("timeout error calling accept: %v", netErr)
+ continue
+ }
+
+ l.logger.Errorf(
+ "permanent error calling Accept(), propagating to vhosts: %v",
+ netErr)
+ } else {
+ l.logger.Errorf(
+ "non-net.Error of type %T calling Accept(), propagating to vhosts: %v",
+ err, err)
+ }
+ l.vhostsMu.Lock()
+ defer l.vhostsMu.Unlock()
+
+ for v := range l.vhosts {
+ v.msg(nil, err)
}
return err
}
- l.logger.V(2).Infof("new conn: %s <-> %s", conn.RemoteAddr(), conn.LocalAddr())
- go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("[%s]", conn.RemoteAddr())))
+ l.logger.V(1).Infof("new conn: %s <-> %s, %d vhosts to check",
+ conn.RemoteAddr(),
+ conn.LocalAddr(),
+ len(l.vhosts))
+
+ go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("<->%s", conn.RemoteAddr())))
}
}
writer: conn,
}
- for _, vhost := range l.vhosts {
+ l.vhostsMu.Lock()
+ defer l.vhostsMu.Unlock()
+ for vhost := range l.vhosts {
if vhost.matcher.MatchesName(clientHello.ServerName) {
- vhost.connChan <- sniConn{childConn, nil}
+ logger.V(1).Infof(
+ "new conn matched vhost %+v for name %q, dispatch",
+ vhost.matcher.Names(),
+ clientHello.ServerName)
+ vhost.msg(childConn, nil)
return
}
}
- logger.Warningf("no servername match: %q", clientHello.ServerName)
+ l.logger.Warningf("no servername match: %q", clientHello.ServerName)
conn.Close()
}
logger: l.logger,
}
- l.vhosts = append(l.vhosts, vhost)
+ l.vhostsMu.Lock()
+ defer l.vhostsMu.Unlock()
+ l.vhosts[vhost] = struct{}{}
return vhost, nil
}
return l.listener.Close()
}
+func (l *SNIListener) deregister(v *sniVirtualHost) {
+ v.logger.V(1).Debugf("deregistering SNI vhost %+v from listener %v", v.matcher.Names(), l.Addr())
+
+ l.vhostsMu.Lock()
+ defer l.vhostsMu.Unlock()
+
+ delete(l.vhosts, v)
+}
+
func (v *sniVirtualHost) Accept() (net.Conn, error) {
c := <-v.connChan
+
+ if c.conn == nil && c.err == nil {
+ return nil, net.ErrClosed
+ }
+
if c.conn != nil {
v.logger.V(3).Noticef("recv'd dispatched conn: %s <-> %s", c.conn.RemoteAddr(), c.conn.LocalAddr())
} else if c.err != nil {
return c.conn, nil
}
- tlsConn := tls.Server(c.conn, v.tlsConfig)
-
- if err := tlsConn.Handshake(); err != nil {
- c.conn.Close()
- return nil, err
- }
-
- return tlsConn, nil
+ return tls.Server(c.conn, v.tlsConfig), nil
}
func (v *sniVirtualHost) Addr() net.Addr {
}
func (v *sniVirtualHost) Close() error {
- return v.sniListener.Close()
+ v.shutdownOnce.Do(func() {
+ v.sniListener.deregister(v)
+
+ if v.connChan != nil {
+ v.msg(nil, net.ErrClosed)
+ close(v.connChan)
+ v.connChan = nil
+ }
+ })
+
+ return nil
+}
+
+func (v *sniVirtualHost) msg(conn net.Conn, err error) {
+ if v.connChan != nil {
+ select {
+ case v.connChan <- sniConn{conn, err}:
+ // yay message sent
+ default:
+ // oops
+ }
+ }
}
// ref: https://www.agwa.name/blog/post/writing_an_sni_proxy_in_go
--- /dev/null
+package http
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "sync"
+ "time"
+
+ rnet "go.fuhry.dev/runtime/net"
+ "go.fuhry.dev/runtime/net/dns"
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+type SNIVirtualHost struct {
+ *BaseVirtualHost `yaml:",inline"`
+
+ Backend *struct {
+ Host string `yaml:"host"`
+ Port int `yaml:"port"`
+ } `yaml:"tcp_backend"`
+
+ ctx context.Context
+ logger log.Logger
+ listener net.Listener
+}
+
+func (v *SNIVirtualHost) Serve(ctx context.Context, sniServer *SNIListener) (func() error, error) {
+ if v.Backend.Port < 1 || v.Backend.Port > 65535 {
+ return nil, errors.New("backend port must be 1-65535")
+ }
+ if len(v.Aliases) < 1 {
+ return nil, errors.New("no aliases specified for TCP vhost")
+ }
+
+ v.logger = LoggerFromContext(ctx).AppendPrefix(fmt.Sprintf(".SNI<%s>", v.Aliases[0]))
+
+ sniListener, err := sniServer.AddVirtualHost(v, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ v.ctx = ctx
+ v.listener = sniListener
+
+ return v.serve, nil
+}
+
+func (v *SNIVirtualHost) Shutdown(ctx context.Context) (err error) {
+ if v.listener != nil {
+ err = v.listener.Close()
+ v.listener = nil
+ }
+ return
+}
+
+func (v *SNIVirtualHost) serve() error {
+ defer v.listener.Close()
+
+ for {
+ conn, err := v.listener.Accept()
+ if err != nil {
+ return err
+ }
+
+ go v.conn(conn)
+ }
+}
+
+func (v *SNIVirtualHost) conn(downstream net.Conn) {
+ defer downstream.Close()
+ requestID := newRequestID()
+
+ start := time.Now()
+
+ logEntry := map[string]string{
+ "t": start.Format(time.RFC3339),
+ "request_id": requestID,
+ "remote_address": downstream.RemoteAddr().String(),
+ "upstream_address": "-",
+ }
+
+ defer (func() {
+ l, _ := json.Marshal(logEntry)
+ v.logger.Info(string(l))
+ })()
+
+ addrs, err := v.resolve()
+ if err != nil {
+ v.logger.Errorf("[%s] error resolving backend %s:%d: %v",
+ requestID, v.Backend.Host, v.Backend.Port, err)
+ fmt.Fprintln(downstream, "error resolving backend")
+ return
+ }
+
+ connCtx, cancel := context.WithTimeout(
+ v.ctx,
+ 5*time.Second)
+ defer cancel()
+
+ upstream, err := rnet.DialHappyEyeballs(connCtx, addrs, uint16(v.Backend.Port))
+ if err != nil {
+ v.logger.Errorf("[%s] error establishing connection to backend %s:%d: %v",
+ requestID, v.Backend.Host, v.Backend.Port, err)
+ fmt.Fprint(downstream, "error establishing connection to backend")
+ return
+ }
+
+ logEntry["upstream_address"] = upstream.RemoteAddr().String()
+
+ defer upstream.Close()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ var tx, rx int64
+ txP := &tx
+ rxP := &rx
+
+ go (func() {
+ defer wg.Done()
+ rx, _ := io.Copy(upstream, downstream)
+ *rxP = rx
+ })()
+ go (func() {
+ defer wg.Done()
+ tx, _ := io.Copy(downstream, upstream)
+ *txP = tx
+ })()
+
+ wg.Wait()
+
+ logEntry["bytes_received"] = fmt.Sprintf("%d", rx)
+ logEntry["bytes_sent"] = fmt.Sprintf("%d", tx)
+}
+
+func (v *SNIVirtualHost) resolve() ([]net.Addr, error) {
+ if v.Backend.Host == "" {
+ return []net.Addr{
+ &net.IPAddr{IP: net.IP{127, 0, 0, 1}},
+ }, nil
+ }
+
+ _, zone, _ := strings.Cut(v.Backend.Host, "%")
+ if ip := net.ParseIP(v.Backend.Host); ip != nil {
+ return []net.Addr{
+ &net.IPAddr{IP: ip, Zone: zone},
+ }, nil
+ }
+
+ ip4, ip6, err := dns.ResolveDualStack(v.Backend.Host)
+ if err != nil {
+ return nil, err
+ }
+
+ var out []net.Addr
+ if ip4 != "" {
+ out = append(out, &net.IPAddr{IP: net.ParseIP(ip4)})
+ }
+ if ip6 != "" {
+ out = append(out, &net.IPAddr{IP: net.ParseIP(ip6)})
+ }
+
+ return out, nil
+}