From fc0e309aa4b1786937947255f6ec1ee8297b32fa Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Fri, 15 Aug 2025 13:01:16 -0400 Subject: [PATCH] [http] SNI proxying, healthcheck action, populate authorization, better handle nested authorization, docs - added readme, because this is seriously getting too complicated for me to configure from memory - refactored connection acceptance with a new `net.Listener` implementation that supports SNI, with virtual listeners that can be used with `http.Server` - foundations laid for bare-TCP SNI proxying too, but implementation not started yet - added `healthcheck` action - added a mutable request context at the logging middleware level, to bubble up request data to the logger - propagate SAML authorization state to global request context - SAML action now skips if request was previously authorized - got multiple listeners + multiple vhosts per listener working No breaking config changes. --- http/README.md | 252 +++++++++++++++++ http/http.go | 177 ++++++++++++ http/proxy/main.go | 31 +-- http/route_action_healthcheck.go | 32 +++ http/route_action_saml.go | 47 +++- http/server.go | 457 +++++++++++++++++++------------ http/sni_listener.go | 277 +++++++++++++++++++ mtls/fsnotify/fsnotify.go | 2 +- utils/log/http.go | 47 +++- utils/log/log.go | 7 +- 10 files changed, 1123 insertions(+), 206 deletions(-) create mode 100644 http/README.md create mode 100644 http/http.go create mode 100644 http/route_action_healthcheck.go create mode 100644 http/sni_listener.go diff --git a/http/README.md b/http/README.md new file mode 100644 index 0000000..9fa9558 --- /dev/null +++ b/http/README.md @@ -0,0 +1,252 @@ +# HTTP sidecar proxy + +OK, this is a little bit of a mess right now, but I needed a sidecar proxy that was tailored to my needs: S3 front-ending, SAML, simple redirection rules, proxy protocol v2, and of course direct proxying to a backend. + +## Architecture + +The server supports an arbitrary number of listeners and an arbitrary number of virtual hosts per listener. Each virtual host can have its own certificate and key pair; the correct virtual host will be selected based on the TLS Server Name Indication (SNI). If a virtual host receives an HTTP request with a Host header that does not match any of its configured aliases, the server responds with an HTTP 421 Misdirected Request status code. + +Each virtual host defines an ordered list of routes, each of which consists of a match rule and an action. + +Actions implement the `RouteAction` interface, which consists of a single method: + +```golang +Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) +``` + +### Startup + +Each route action is registered using `AddRouteParseFunc`. The function passed to `AddRouteParseFunc` returns: + +* `nil, nil` if the route doesn't have settings for that handler +* `nil, error` if the route has settings for that handler and there was a parse error +* `RouteAction, nil` if the route was successfully configured + +### Request path + +When a request arrives, the server walks each loaded route. The URI path (with leading slash, and without the query string) is used as the string matcher input. If the request matches the route, the action is called. + +The action uses a middleware pattern, i.e., it has the option of either handling the request directly, or continuing request processing with the next (matching) route. The action may optionally mutate the request, i.e. by adding/removing headers, rewriting the path, adding context variables, etc., or can forward the request to the next action verbatim. + +This means that the interface for filters (which mutate or transform a request before passing it along) and handlers (which serve a response) is identical and no distinction is made by the server between filters and handlers. Some route actions act as both, i.e., the `saml` action acts as a handler when serving an IdP redirect or the assertion consumer service, but acts as a filter if the request is authorized. + +If a route matches the request but no action is configured, the server responds with `404 Not Found` and a "no action configured for route" error in the body. + +If no route matches the request URI, the server responds with a plain `404 Not Found` status and body. + +If the incoming request doesn't match the configured virtual host, the server sends the `421 Misdirected Request` status code. + +### Authorization + +Authorization actions (currently only `saml`) are expected to check the request authorization state by calling `AuthorizationFromContext`. If this returns a non-nil value, the authorization handler should consider the request to have been authorized by a previous handler. + +This means that the **first** matching route defining an authorization handler is authoritative for a request. Routes should thus be ordered with deeper/more specific routes first, and top-level/fallback routes last. + +## Example config + +This shows the complete YAML schema, including all optional keys. + +```yaml +listener: + listen: "[::]:443" + listen_insecure: "[::]:80" + external_port: 443 + proxy_protocol: true + # this is passed to go.fuhry.dev/runtime/mtls.LoadSSLCertificateFromFilesystem + # default paths: + # leaf and intermediates: /etc/ssl/private/www.example.com/fullchain.pem + # private key: /etc/ssl/private/www.example.com/privkey.pem + # "/etc/ssl/private" is the default value of the `-tls.certs-dir` flag + certificate: www.example.com + trust_upstream_request_id: false + virtual_hosts: + www.example.com: + routes: + # Route type: Healthcheck + - path: + mode: exact + value: /healthz + healthcheck: {} + # Route type: SAML + - path: + mode: prefix + value: / + auth: saml + saml: + require: true + # if false, any cookie starting with "saml_" is stripped from the request + # before the next route is processed + preserve_cookies: false + sp: + # same settings as `saml` top level key + entity_id: ... + # Route type: redirect + - path: + mode: exact + value: /legacy/route + redirect: + dest: /new/route + status: 307 + # Route type: proxy + - path: + mode: prefix + value: / + proxy: + host: 127.0.0.1 + port: 8999 + # Expected mTLS ID for the backend server. If omitted, connects to the + # backend over plain http. + mtls_id: foo + # Route type: s3 + - path: + mode: prefix + value: /assets/ + s3: + endpoint: s3.us-east-1.amazonaws.com + access_key: AKIAZZZZZZZZZZZZZZZZZZZZZZZZZZZ + secret_key: XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX + bucket: example + prefix: /example.com/static/ + # optional; if specified, this string is trimmed from the front of the URI + # path + strip_prefix: /assets/ + # with this example configuration, a request to + # /assets/images/example.png + # would map to + # /example.com/static/images/example.png +saml: + # Global SAML configuration; may be overridden at the listener or route level + entity_id: "urn:www.example.com" + entity_certificate: /data/saml.crt + entity_key: /data/saml.key + # URL to IdP metadata + idp: https://auth.example.com/saml2/metadata +``` + +## Matchers + +Syntax: + +```yaml +match: + mode: matcher_type + # for all modes except "any", "never", "and" and "or": + value: value_to_match + # only for "and" and "or" modes: + rules: + - mode: matcher_type + value: value_one + - mode: matcher_type + value: value_two +``` + +Valid `mode`s are: + +* `any` (does not accept a value) +* `never` (does not accept a value) +* `prefix` +* `suffix` +* `exact` +* `contains` +* `regexp` +* `never` +* `and` +* `or` + +Note that `and` and `or` matchers can be infinitely nested. + +Reference: [go.fuhry.dev/runtime/utils/stringmatch/serialization.go](../utils/stringmatch/serialization.go) + +## Actions + +Each route must have **exactly one** action. Server behavior if multiple actions are specified is undefined. + +### Healthcheck + +**Key:** `healthcheck` + +The healthcheck action simply returns a `200 OK` response with the body containing the string `LIVE` followed by a newline (LF). + +In the future, this endpoint will return a `502` status with the body set to a descriptive string if the server should not accept requests. + +No additional configuration. + +_Example:_ `healthcheck: {}` + +### SAML + +**Keys:** `auth` (must be set to `saml`); `saml` + +Authorization action which allows the proxy to function as a SAML service provider (SP). + +If there are multiple matching routes with SAML actions for a request, the first one to match is the one that is authoritative for the request. + +If a valid session is present, the `saml` action sets following request headers: + +* `on-behalf-of`: the value of the "uid" attribute from the assertion, if present, otherwise the NameID +* `x-saml-audience`: The service provider's entity ID +* `x-saml-issued-at`: The `NotBefore` attribute of the SAML assertion, formatted as a UNIX timestamp. +* `x-saml-expires-at`: The `NotAfter` attribute of the SAML assertion, formatted as a UNIX timestamp. +* `x-saml-subject`: NameID in the assertion, which is often either an email address or opaque UUID +* Any attributes from the assertion are normalized and forwarded as follows: + 1. The attribute name is transformed to lowercase + 2. Any sequence of one or more non-alphanumeric characters is replaced with a single dash (`-`). + 3. If there are multiple attributes with the same name, the values are sent as a single concatenated HTTP header, separated by a comma followed by a space +* `x-saml-anonymous-auth`: set to `1` if `required` is `false` and no valid session is present; omitted otherwise. + +**Important:** The URI `/saml/acs` **MUST** be handled by `saml`. This can be accomplished with either an `exact` matcher with the value `/saml/acs`, or `prefix` of `/` or `/saml`. This path is hardcoded and not presently configurable. + +Settings: + +* `require` _(bool)_: If true, and no valid session is present, the request is redirected to the IdP to complete authentication. + +### Redirect + +**Key:** `redirect` + +Redirects matching requests to the specified destination. + +Settings: + +* `dest` _(string, **required**)_: Absolute or relative URI to redirect to. +* `status` _(int)_: Custom HTTP status code to use. Defaults to 302. + + +### Proxy + +**Key:** `proxy` + +Forwards the request to the specified backend. + +Settings: + +* `port` _(int, **required**)_: Destination port +* `host`: Destination hostname or IP; defaults to `127.0.0.1` + +### S3 + +**Key:** `s3` + +Forwards GET requests to an Amazon S3-compatible object store. + +Settings: + +* `endpoint` (string, **required**): S3 endpoint. Currently, only path-style requests are supported. Example for AWS S3: `s3.us-east-1.amazonaws.com` +* `access_key` (string, **required**): IAM access key +* `secret_key` (string, **required**): IAM secret key +* `bucket` (string, **required**): Bucket name +* `prefix` (string): static prefix to prepend to the object path, i.e., subpath under the bucket root +* `strip_prefix` (string): static prefix to remove from the beginning of the object path. Prefix matches are not automatically stripped. + +## Request ID support + +All requests are assigned a random request ID, which is a 20 digit random hex string. + +If the `trust_upstream_request_id` listener setting is `true`, and the incoming request contains the `x-request-id` header, the value of that header is reused. Otherwise, a new request ID is generated at the start of request processing. + +## Roadmap + +Things that aren't currently implemented, and need to be: + +* Loading IdP metadata from filesystem instead of HTTP URL +* Draining support \ No newline at end of file diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..1dfb1f3 --- /dev/null +++ b/http/http.go @@ -0,0 +1,177 @@ +package http + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "regexp" + "strings" + + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/utils/log" +) + +type HTTPVirtualHost struct { + *BaseVirtualHost `yaml:",inline"` + + Routes []*Route `yaml:"routes"` + Certificate string `yaml:"certificate"` + TrustUpstreamRequestID *bool `yaml:"trust_upstream_request_id"` + + tlsServer *http.Server +} + +func (v *HTTPVirtualHost) Serve(ctx context.Context, sniServer *SNIListener) (func() error, error) { + server, err := v.NewHTTPServerWithContext(ctx) + if err != nil { + return nil, err + } + + if server.TLSConfig == nil { + return nil, errors.New("no tls config set on http vhost") + } + + sniListener, err := sniServer.AddVirtualHost(v, server.TLSConfig) + if err != nil { + return nil, err + } + v.tlsServer = server + + return func() error { + return server.Serve(sniListener) + }, nil +} + +// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host +// and other settings. +func (v *HTTPVirtualHost) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { + logger := LoggerFromContext(ctx) + serverCtx := context.WithValue(ctx, kLogger, logger) + + lm := log.NewLoggingMiddlewareWithLogger( + http.HandlerFunc(v.handle), + logger.AppendPrefix(".access").WithLevel(log.INFO)) + lm.ExtraFunc = v.injectLogAttrs + + server := &http.Server{ + Addr: "", + BaseContext: func(l net.Listener) context.Context { + return context.WithValue(serverCtx, kListenAddr, l.Addr()) + }, + Handler: lm.HandlerFunc(), + } + + if v.Certificate != "" { + cert := mtls.NewSSLCertificate(v.Certificate) + tlsConfig, err := cert.TlsConfig(serverCtx) + if err != nil { + return nil, err + } + tlsConfig.NextProtos = []string{"h2"} + server.TLSConfig = tlsConfig + } + + return server, nil +} + +func (v *HTTPVirtualHost) Shutdown(ctx context.Context) error { + if v.tlsServer != nil { + return v.tlsServer.Shutdown(ctx) + } + + return nil +} + +func (v *HTTPVirtualHost) handle(w http.ResponseWriter, r *http.Request) { + // ensure host header present + if r.Host == "" { + r.Header.Write(os.Stderr) + http.Error(w, "missing Host header", http.StatusBadRequest) + return + } + + rid := r.Header.Get("x-request-id") + if rid == "" || !*v.TrustUpstreamRequestID { + rid = newRequestID() + } + w.Header().Set("x-request-id", rid) + log.StateFromContext(r.Context()).Set(kRequestID, rid) + + if !v.MatchesName(r.Host) && !v.MatchesName(portSpec.ReplaceAllLiteralString(r.Host, "")) { + http.Error(w, "Misdirected request: unknown virtual host", + http.StatusMisdirectedRequest) + return + } + + v.fulfill(w, r, v.Routes) +} + +func vhostGlobToRegexp(name string) (*regexp.Regexp, error) { + var parts []string + for _, part := range strings.Split(name, ".") { + switch { + case part == singleHostnameToken: + parts = append(parts, hostnameWildcardSingle) + case part == multiHostnameToken: + parts = append(parts, hostnameWildcardMulti) + case domainComponentSpec.MatchString(part): + parts = append(parts, part) + default: + return nil, fmt.Errorf("invalid virtual host name or alias: %q", name) + } + } + return regexp.Compile("^" + strings.Join(parts, `\.`) + "$") +} + +func (v *HTTPVirtualHost) fulfill(w http.ResponseWriter, r *http.Request, routes []*Route) { + logger := LoggerFromContext(r.Context()) + if logger == nil { + http.Error(w, "cannot get logger", http.StatusInternalServerError) + } + + logger.V(3).Debugf("checking for routes matching %s", r.URL) + for i, route := range routes { + match := false + if route.Path != nil { + match = route.Path.Match(r.URL.Path) + logger.V(3).Debugf("path %s matches %s: %t", + r.URL.Path, route.Path.String(), match) + } else { + http.Error(w, "nothing to match on in route", http.StatusInternalServerError) + } + + if match { + if route.Action != nil { + logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action) + next := http.NotFound + if len(routes) > i { + next = func(w http.ResponseWriter, r *http.Request) { + logger.V(3).Debugf("%T called next(), continuing request processing", route.Action) + v.fulfill(w, r, routes[i+1:]) + } + } + route.Action.Handle(w, r, next) + return + } else { + http.Error(w, + fmt.Sprintf("no action configured for route %s", route.Path.String()), + http.StatusInternalServerError) + } + } + } + + http.NotFound(w, r) +} + +func (v *HTTPVirtualHost) injectLogAttrs(entry map[string]any, w http.ResponseWriter, r *http.Request) { + if authz := AuthorizationFromContext(r.Context()); authz != nil { + entry["authorized_by"] = authz.Authorizer + entry["user"] = authz.Principal + } + if rid, ok := log.StateFromContext(r.Context()).Get(kRequestID).(string); ok { + entry["request_id"] = rid + } +} diff --git a/http/proxy/main.go b/http/proxy/main.go index 25e610f..b4eb445 100644 --- a/http/proxy/main.go +++ b/http/proxy/main.go @@ -22,33 +22,29 @@ func main() { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - server := http.NewServerWithContext(ctx) + configPath := flag.String("config", "", "YAML file to load configuration from") + flag.Parse() - loadConfig := func(arg string) error { - contents, err := os.ReadFile(arg) - if err != nil { - return err - } + server := http.NewServerWithContext(ctx) - err = yaml.Unmarshal(contents, server) - return err + contents, err := os.ReadFile(*configPath) + if err != nil { + log.Default().Panic(err) } - flag.Func("config", "YAML file to load configuration from", loadConfig) - flag.StringVar(&server.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private") - flag.StringVar(&server.Listener.Addr, "listen", "[::]:8443", "address for auth proxy to listen on") - flag.StringVar(&server.Listener.InsecureAddr, "listen.http", "[::]:8080", "address for http-to-https redirector") - - flag.Parse() + err = yaml.Unmarshal(contents, server) + if err != nil { + log.Default().Panic(err) + } go (func() { - if err := server.ListenAndServeTLS(); err != nil { + if err := server.ListenAndServeTLS(); err != nil && err != http.ErrServerClosed { log.Panic(err) } })() go (func() { - if err := server.ListenAndServe(); err != nil { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Panic(err) } })() @@ -56,7 +52,10 @@ func main() { daemon.SdNotify(false, daemon.SdNotifyReady) <-ctx.Done() + + log.Default().Noticef("Server shutting down") shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) defer shutdownCancel() server.Shutdown(shutdownCtx) + log.Default().Noticef("Exiting normally") } diff --git a/http/route_action_healthcheck.go b/http/route_action_healthcheck.go new file mode 100644 index 0000000..4225b7b --- /dev/null +++ b/http/route_action_healthcheck.go @@ -0,0 +1,32 @@ +package http + +import ( + "net/http" + + "gopkg.in/yaml.v3" +) + +type HealthCheckAction struct{} + +func (a *HealthCheckAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + w.WriteHeader(http.StatusOK) + w.Header().Set("content-type", "text/plain") + + w.Write([]byte("LIVE\n")) +} + +func healthCheckFromRouteYaml(node *yaml.Node) (RouteAction, error) { + var rawNode struct { + HealthCheck map[string]any `yaml:"healthcheck,omitempty"` + } + + if err := node.Decode(&rawNode); err == nil && rawNode.HealthCheck != nil { + return &HealthCheckAction{}, nil + } + + return nil, nil +} + +func init() { + AddRouteParseFunc(healthCheckFromRouteYaml) +} diff --git a/http/route_action_saml.go b/http/route_action_saml.go index 6594fce..9589651 100644 --- a/http/route_action_saml.go +++ b/http/route_action_saml.go @@ -53,6 +53,13 @@ var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"}) func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { logger := LoggerFromContext(r.Context()) + // if request was already authorized, do nothing + if authz := AuthorizationFromContext(r.Context()); authz != nil { + logger.V(3).Debugf("req already authorized for %q by %q, skip saml", authz.Principal, authz.Authorizer) + next(w, r) + return + } + // ensure client isn't trying to inject saml-related headers if err := sa.checkRequest(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -107,14 +114,12 @@ func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.H return } + var principal string if swa, ok := session.(samlsp.SessionWithAttributes); ok { attrs := swa.GetAttributes() - oboHeader := sa.usernameHeader - if oboHeader == "" { - oboHeader = "on-behalf-of" + if uid := attrs.Get("uid"); uid != "" { + principal = uid } - logger.V(3).Debugf("setting origin request header: %s: %q", oboHeader, attrs.Get("uid")) - r.Header.Set(oboHeader, attrs.Get("uid")) } if jwts, ok := session.(samlsp.JWTSessionClaims); ok { @@ -131,6 +136,9 @@ func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.H r.Header.Set("x-saml-subject", jwts.StandardClaims.Subject) logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject) + if principal == "" { + principal = jwts.StandardClaims.Subject + } for attr, values := range jwts.Attributes { headerName := fmt.Sprintf("x-saml-%s", @@ -140,6 +148,13 @@ func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.H headerName, headerValue) r.Header.Set(headerName, headerValue) } + + oboHeader := sa.usernameHeader + if oboHeader == "" { + oboHeader = "on-behalf-of" + } + logger.V(3).Debugf("setting origin request header: %s: %s", oboHeader, principal) + r.Header.Set(oboHeader, principal) } else { r.Header.Set("x-saml-anonymous-auth", "1") } @@ -158,6 +173,11 @@ func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.H r = newReq } + SetAuthorizationContext(r.Context(), &Authorization{ + Authorizer: "saml", + Principal: principal, + }) + next(w, r) } @@ -363,15 +383,22 @@ func samlActionFromRouteYaml(node *yaml.Node) (RouteAction, error) { return nil, nil } - require, err := strconv.ParseBool(rawNode.SP.Require) - if err != nil { - return nil, err + var spConfig *SAMLServiceProvider + require := true + preserveCookies := false + if rawNode.SP != nil { + require, err = strconv.ParseBool(rawNode.SP.Require) + if err != nil { + return nil, err + } + spConfig = rawNode.SP.SAMLServiceProvider + preserveCookies = rawNode.SP.PreserveCookies } sa := &samlAction{ - sp: rawNode.SP.SAMLServiceProvider, + sp: spConfig, requireAuth: require, - preserveCookies: rawNode.SP.PreserveCookies, + preserveCookies: preserveCookies, } return sa, nil diff --git a/http/server.go b/http/server.go index b845776..1194124 100644 --- a/http/server.go +++ b/http/server.go @@ -2,19 +2,20 @@ package http import ( "context" - "crypto/tls" "encoding/hex" "errors" "fmt" "math/rand" "net" "net/http" - "os" "regexp" + "slices" + "strconv" + "strings" "sync" "time" - "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/utils/hashset" "go.fuhry.dev/runtime/utils/log" "go.fuhry.dev/runtime/utils/stringmatch" "gopkg.in/yaml.v3" @@ -31,25 +32,50 @@ type Route struct { Action RouteAction } -type VirtualHost struct { - Routes []*Route `yaml:"routes"` +type NameMatcher interface { + MatchesName(string) bool +} + +type VirtualHost interface { + NameMatcher + + Serve(context.Context, *SNIListener) (func() error, error) + Shutdown(context.Context) error +} + +type BaseVirtualHost struct { + Aliases []string `yaml:"aliases"` + + aliasesExact *hashset.HashSet[string] + aliasesCompiled []*regexp.Regexp } type Listener struct { - Addr string `yaml:"listen"` - ProxyProtocol bool `yaml:"proxy_protocol"` - InsecureAddr string `yaml:"listen_insecure"` - Certificate string `yaml:"cert"` - TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"` - VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"` + // Addr is the IP address and port to listen on for TLS connections. + Addr string `yaml:"listen"` + // ExternalPort is an optional port to use when sending http-to-https redirects. If external traffic + // enters through a different port than what the server is bound to locally, this setting ensures that + // redirects reference the correct port. + ExternalPort int `yaml:"external_port"` + ProxyProtocol bool `yaml:"proxy_protocol"` + InsecureAddr string `yaml:"listen_insecure"` + Certificate string `yaml:"cert"` + TrustUpstreamRequestID bool `yaml:"trust_upstream_request_id"` + VirtualHosts []VirtualHost `yaml:"-"` + HTTPVirtualHosts map[string]*HTTPVirtualHost `yaml:"virtual_hosts"` + + sniServer *SNIListener + httpServer *http.Server } type Server struct { - Listener *Listener `yaml:"listener"` - Context context.Context `yaml:"-"` + Listeners []*Listener `yaml:"listeners"` + Context context.Context `yaml:"-"` +} - tlsServer *http.Server - httpServer *http.Server +type Authorization struct { + Authorizer string + Principal string } type initHook func(context.Context, *yaml.Node) (context.Context, error) @@ -59,12 +85,24 @@ const ( kLogger serverCtxVar = iota kListener kListenAddr + kVirtualHost kSamlDefaults kRequestID + kAuthorization ) +const ( + singleHostnameToken = `*` + multiHostnameToken = `**` + + hostnameWildcardSingle = `[A-Za-z0-9-]+` + hostnameWildcardMulti = `[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*` +) + +var ErrServerClosed = http.ErrServerClosed var randSrc = rand.New(rand.NewSource(time.Now().UnixNano())) var portSpec = regexp.MustCompile(":[0-9]{1,5}$") +var domainComponentSpec = regexp.MustCompile("^" + hostnameWildcardSingle + "$") var initHooks []initHook var routeParseFuncs []routeParseFunc @@ -84,9 +122,6 @@ func NewServerWithContext(ctx context.Context) *Server { logger := log.WithPrefix(fmt.Sprintf("%T", &Server{})) return &Server{ - Listener: &Listener{ - VirtualHosts: make(map[string]*VirtualHost, 0), - }, Context: context.WithValue(ctx, kLogger, logger), } } @@ -127,9 +162,10 @@ func (r *Route) UnmarshalYAML(node *yaml.Node) error { // UnmarshalYAML implements yaml.Unmarshaler func (s *Server) UnmarshalYAML(node *yaml.Node) error { - lc := &struct { - Listener *Listener `yaml:"listener"` - }{} + var lc struct { + Listener *Listener `yaml:"listener"` + Listeners []*Listener `yaml:"listeners"` + } if s.Context == nil { s.Context = context.Background() @@ -139,7 +175,42 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error { return err } - s.Listener = lc.Listener + if lc.Listener != nil && len(lc.Listeners) > 0 { + return errors.New("both \"listener\" and \"listeners\" set in config yaml, please set only one") + } + + if lc.Listener != nil { + lc.Listeners = []*Listener{lc.Listener} + } + + s.Listeners = lc.Listeners + + for _, listener := range s.Listeners { + for name, vhost := range listener.HTTPVirtualHosts { + if vhost.BaseVirtualHost == nil { + vhost.BaseVirtualHost = &BaseVirtualHost{} + } + vhost.Aliases = append(vhost.Aliases, name) + listener.VirtualHosts = append(listener.VirtualHosts, vhost) + } + + for i, vhost := range listener.VirtualHosts { + switch vhost := vhost.(type) { + case *HTTPVirtualHost: + if vhost.Certificate == "" { + vhost.Certificate = listener.Certificate + } + if vhost.TrustUpstreamRequestID == nil { + vhost.TrustUpstreamRequestID = &listener.TrustUpstreamRequestID + } + if err := vhost.precompileAliases(); err != nil { + return fmt.Errorf("bootstrapping vhost %d: %v", i, err) + } + default: + return fmt.Errorf("unsupported vhost type: %T", vhost) + } + } + } for _, initHook := range initHooks { newCtx, err := initHook(s.Context, node) @@ -152,219 +223,246 @@ func (s *Server) UnmarshalYAML(node *yaml.Node) error { return nil } -func (s *Server) ListenAndServeTLS() error { - listenerCtx := context.WithValue(s.Context, kListener, s.Listener) - server, err := s.Listener.NewHTTPServerWithContext(listenerCtx) - if err != nil { - return err - } +func (s *Server) ListenAndServeTLS() (err error) { + wg := &sync.WaitGroup{} - var listener net.Listener - if s.Listener.ProxyProtocol { - listener, err = ListenProxyProtocol("tcp", server.Addr) - LoggerFromContext(listenerCtx).Noticef( - "Listening for standard or PROXY protocol TLS connections on %s", listener.Addr(), - ) - } else { - listener, err = net.Listen("tcp", server.Addr) - LoggerFromContext(listenerCtx).Noticef( - "Listening for TLS connnections on %s", listener.Addr(), - ) - } - if err != nil { - return err - } - if server.TLSConfig != nil { - listener = tls.NewListener(listener, server.TLSConfig) + for _, listener := range s.Listeners { + listenerCtx := context.WithValue(s.Context, kListener, listener) + + wg.Add(1) + go (func() { + defer wg.Done() + err = listener.ListenAndServeSNI(listenerCtx) + })() } - s.tlsServer = server - return server.Serve(listener) + wg.Wait() + return err } -func (s *Server) ListenAndServe() error { - var listener net.Listener - var err error +func (s *Server) ListenAndServe() (err error) { + wg := &sync.WaitGroup{} + for _, listener := range s.Listeners { + var netListener net.Listener + + listenerCtx := context.WithValue(s.Context, kListener, listener) + server := listener.NewHTTPSRedirectorWithContext(listenerCtx) + if listener.ProxyProtocol { + netListener, err = ListenProxyProtocol("tcp", server.Addr) + if err != nil { + return err + } + LoggerFromContext(listenerCtx).Noticef( + "Listening for standard or PROXY protocol connections on %s", netListener.Addr(), + ) + } else { + netListener, err = net.Listen("tcp", server.Addr) + if err != nil { + return err + } + LoggerFromContext(listenerCtx).Noticef( + "Listening on %s", netListener.Addr(), + ) + } - listenerCtx := context.WithValue(s.Context, kListener, s.Listener) - server := s.Listener.NewHTTPSRedirectorWithContext(listenerCtx) - if s.Listener.ProxyProtocol { - listener, err = ListenProxyProtocol("tcp", server.Addr) - LoggerFromContext(listenerCtx).Noticef( - "Listening for standard or PROXY protocol connections on %s", listener.Addr(), - ) - } else { - listener, err = net.Listen("tcp", server.Addr) - LoggerFromContext(listenerCtx).Noticef( - "Listening on %s", listener.Addr(), - ) - } - if err != nil { - return err + listener.httpServer = server + wg.Add(1) + go (func() { + defer wg.Done() + err = server.Serve(netListener) + })() } - s.httpServer = server - return server.Serve(listener) + wg.Wait() + return err } func (s *Server) Shutdown(shutdownCtx context.Context) error { var wg sync.WaitGroup var err error - if s.httpServer != nil { - wg.Add(1) + for _, listener := range s.Listeners { + if listener.httpServer != nil { + wg.Add(1) + + go (func() { + defer wg.Done() + err = listener.httpServer.Shutdown(shutdownCtx) + })() + } - go (func() { - defer wg.Done() - err = s.httpServer.Shutdown(shutdownCtx) - })() - } - if s.tlsServer != nil { - wg.Add(1) - go (func() { - defer wg.Done() - err = s.tlsServer.Shutdown(shutdownCtx) - })() + if listener.sniServer != nil { + wg.Add(1) + go (func() { + defer wg.Done() + err = listener.sniServer.Close() + })() + + for _, vhost := range listener.VirtualHosts { + wg.Add(1) + go (func() { + defer wg.Done() + err = vhost.Shutdown(shutdownCtx) + })() + } + } } wg.Wait() return err } -// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host -// and other settings. -func (l *Listener) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { +func (l *Listener) ListenAndServeSNI(ctx context.Context) (err error) { if l.Addr == "" { l.Addr = "[::]:8443" } logger := LoggerFromContext(ctx).WithPrefix(fmt.Sprintf("%T(%s)", l, l.Addr)) - serverCtx := context.WithValue(ctx, kLogger, logger) - lm := log.NewLoggingMiddlewareWithLogger( - http.HandlerFunc(l.handle), - logger.AppendPrefix(".access")) - lm.AddResponseHeader("x-request-id") + if len(l.VirtualHosts) < 1 { + return errors.New("listener has no virtual hosts configured") + } - server := &http.Server{ - Addr: l.Addr, - BaseContext: func(l net.Listener) context.Context { - return context.WithValue(serverCtx, kListenAddr, l.Addr()) - }, - Handler: lm.HandlerFunc(), + var netListener net.Listener + if l.ProxyProtocol { + netListener, err = ListenProxyProtocol("tcp", l.Addr) + if err != nil { + return err + } + logger.Noticef( + "Listening for standard or PROXY protocol TLS connections on %s", netListener.Addr(), + ) + } else { + netListener, err = net.Listen("tcp", l.Addr) + if err != nil { + return err + } + logger.Noticef( + "Listening for TLS connnections on %s", netListener.Addr(), + ) } - if l.Certificate != "" { - cert := mtls.NewSSLCertificate(l.Certificate) - tlsConfig, err := cert.TlsConfig(serverCtx) + l.sniServer = NewSNIListener(netListener) + + wg := &sync.WaitGroup{} + for _, vhost := range l.VirtualHosts { + serveFunc, err := vhost.Serve(ctx, l.sniServer) if err != nil { - return nil, err + return err } - tlsConfig.NextProtos = []string{"h2"} - server.TLSConfig = tlsConfig + + wg.Add(1) + go (func() { + defer wg.Done() + serveFunc() + })() } - return server, nil + wg.Add(1) + go (func() { + defer wg.Done() + err = l.sniServer.Serve() + })() + + return err } func (l *Listener) NewHTTPSRedirectorWithContext(ctx context.Context) *http.Server { if l.InsecureAddr == "" { l.InsecureAddr = "[::]:8080" } + logger := LoggerFromContext(ctx) - server := &http.Server{ - Addr: l.InsecureAddr, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := r.Host - if host == "" { - w.WriteHeader(http.StatusBadRequest) - return - } - - if _, ok := l.VirtualHosts[host]; !ok { - w.WriteHeader(http.StatusMisdirectedRequest) - return - } + lm := log.NewLoggingMiddlewareWithLogger( + http.HandlerFunc(l.handleRedirectToHTTPS), + logger.AppendPrefix(".access").WithLevel(log.INFO)) - newUrl := *r.URL - newUrl.Scheme = "https" - newUrl.Host = host - w.Header().Set("location", newUrl.String()) - w.WriteHeader(http.StatusFound) - }), + server := &http.Server{ + Addr: l.InsecureAddr, + Handler: lm.HandlerFunc(), } return server } -func (l *Listener) handle(w http.ResponseWriter, r *http.Request) { - // ensure host header present - if r.Host == "" { - r.Header.Write(os.Stderr) - http.Error(w, "missing Host header", http.StatusBadRequest) +func (l *Listener) handleRedirectToHTTPS(w http.ResponseWriter, r *http.Request) { + host := r.Host + if host == "" { + w.WriteHeader(http.StatusBadRequest) return } - reqIdCtx := r.Context() - rid := r.Header.Get("x-request-id") - if rid == "" || !l.TrustUpstreamRequestID { - rid = newRequestID() - } - w.Header().Set("x-request-id", rid) - reqIdCtx = context.WithValue(reqIdCtx, kRequestID, rid) - - // make sure this host is known - vhost, ok := l.VirtualHosts[r.Host] - if !ok { - h := portSpec.ReplaceAllString(r.Host, "") - vhost, ok = l.VirtualHosts[h] - if !ok { - http.Error(w, "Misdirected request: unknown virtual host", - http.StatusMisdirectedRequest) - return + if vhost := l.lookupHost(host); vhost != nil { + _, securePort, _ := net.SplitHostPort(l.Addr) + if l.ExternalPort != 0 { + securePort = strconv.Itoa(l.ExternalPort) } + destHost := portSpec.ReplaceAllLiteralString(host, "") + if securePort != "443" { + destHost = net.JoinHostPort(destHost, securePort) + } + newUrl := *r.URL + newUrl.Scheme = "https" + newUrl.Host = destHost + + w.Header().Set("location", newUrl.String()) + w.WriteHeader(http.StatusFound) + return } - l.fulfill(w, r.WithContext(reqIdCtx), vhost.Routes) + w.WriteHeader(http.StatusMisdirectedRequest) } -func (l *Listener) fulfill(w http.ResponseWriter, r *http.Request, routes []*Route) { - logger := LoggerFromContext(r.Context()) - if logger == nil { - http.Error(w, "cannot get logger", http.StatusInternalServerError) +func (l *Listener) lookupHost(hostname string) VirtualHost { + for _, vhost := range l.VirtualHosts { + if vhost.MatchesName(hostname) { + return vhost + } + } + + // if no match, strip port from hostname and try again + if portSpec.MatchString(hostname) { + return l.lookupHost(portSpec.ReplaceAllString(hostname, "")) } - logger.V(3).Debugf("checking for routes matching %s", r.URL) - for i, route := range routes { - match := false - if route.Path != nil { - match = route.Path.Match(r.URL.Path) - logger.V(3).Debugf("path %s matches %s: %t", - r.URL.Path, route.Path.String(), match) - } else { - http.Error(w, "nothing to match on in route", http.StatusInternalServerError) - } + return nil +} - if match { - if route.Action != nil { - logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action) - next := http.NotFound - if len(routes) > i { - next = func(w http.ResponseWriter, r *http.Request) { - logger.V(3).Debugf("%T called next(), continuing request processing", route.Action) - l.fulfill(w, r, routes[i+1:]) - } +func (v *BaseVirtualHost) precompileAliases() error { + v.aliasesExact = hashset.NewHashSet[string]() + for i, a := range v.Aliases { + parts := strings.Split(a, ".") + if slices.Contains(parts, singleHostnameToken) || slices.Contains(parts, multiHostnameToken) { + re, err := vhostGlobToRegexp(a) + if err != nil { + return fmt.Errorf("compling vhost alias with index %d (%q): %v", i, a, err) + } + v.aliasesCompiled = append(v.aliasesCompiled, re) + } else { + for _, part := range parts { + if !domainComponentSpec.MatchString(part) { + return fmt.Errorf("vhost alias with index %d (%q) is not a valid domain name", i, a) } - route.Action.Handle(w, r, next) - return - } else { - http.Error(w, - fmt.Sprintf("no action configured for route %s", route.Path.String()), - http.StatusInternalServerError) } + + v.aliasesExact.Add(a) + } + } + + return nil +} + +func (v *BaseVirtualHost) MatchesName(hostname string) bool { + if v.aliasesExact.Contains(hostname) { + return true + } + + for _, a := range v.aliasesCompiled { + if a.MatchString(hostname) { + return true } } - http.NotFound(w, r) + return false } func LoggerFromContext(ctx context.Context) log.Logger { @@ -379,6 +477,19 @@ func LoggerFromContext(ctx context.Context) log.Logger { return nil } +func AuthorizationFromContext(ctx context.Context) *Authorization { + a := log.StateFromContext(ctx).Get(kAuthorization) + if authorization, ok := a.(*Authorization); ok { + return authorization + } + + return nil +} + +func SetAuthorizationContext(ctx context.Context, authz *Authorization) { + log.StateFromContext(ctx).Set(kAuthorization, authz) +} + func newRequestID() string { buf := make([]byte, 10) _, _ = randSrc.Read(buf) diff --git a/http/sni_listener.go b/http/sni_listener.go new file mode 100644 index 0000000..a7d501b --- /dev/null +++ b/http/sni_listener.go @@ -0,0 +1,277 @@ +package http + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "time" + + "go.fuhry.dev/runtime/utils/log" +) + +type sniConn struct { + conn net.Conn + err error +} + +type sniVirtualHost struct { + tlsConfig *tls.Config + sniListener *SNIListener + matcher NameMatcher + connChan chan sniConn + logger log.Logger +} + +// SNIListener is a multiplexing listener that routes connections based on the +// Server Name Indication (SNI) property of incoming TLS handshakes. +// +// SNIListener must be initialized with an underlying net.Listener (usually +// created with net.ListenTCP). +type SNIListener struct { + listener net.Listener + vhosts []*sniVirtualHost + logger log.Logger +} + +func NewSNIListener(l net.Listener) *SNIListener { + logger := log.WithPrefix(fmt.Sprintf("sni://%s", l.Addr())) + listener := &SNIListener{ + listener: l, + logger: logger, + } + + return listener +} + +func (l *SNIListener) Serve() error { + for { + conn, err := l.listener.Accept() + if err != nil { + for _, v := range l.vhosts { + v.connChan <- sniConn{nil, err} + } + return err + } + + l.logger.V(2).Infof("new conn: %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + go l.handle(conn, l.logger.AppendPrefix(fmt.Sprintf("[%s]", conn.RemoteAddr()))) + } +} + +func (l *SNIListener) handle(conn net.Conn, logger log.Logger) { + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + logger.Warningf("SetReadDeadline: %v", err) + conn.Close() + return + } + + clientHello, clientReader, err := peekClientHello(conn) + if err != nil || clientHello.ServerName == "" { + logger.Warningf("peekClientHello: %v", err) + conn.Close() + return + } + + logger.V(2).Infof("ServerName: %s", clientHello.ServerName) + + if err = conn.SetReadDeadline(time.Time{}); err != nil { + logger.Warningf("SetReadDeadline: %v", err) + conn.Close() + return + } + + childConn := &compositeConn{ + reader: clientReader, + writer: conn, + } + + for _, vhost := range l.vhosts { + if vhost.matcher.MatchesName(clientHello.ServerName) { + vhost.connChan <- sniConn{childConn, nil} + return + } + } + + logger.Warningf("no servername match: %q", clientHello.ServerName) + conn.Close() +} + +func (l *SNIListener) AddVirtualHost(matcher NameMatcher, tlsConfig *tls.Config) (net.Listener, error) { + vhost := &sniVirtualHost{ + tlsConfig: tlsConfig, + sniListener: l, + matcher: matcher, + connChan: make(chan sniConn), + logger: l.logger, + } + + l.vhosts = append(l.vhosts, vhost) + + return vhost, nil +} + +func (l *SNIListener) Addr() net.Addr { + return l.listener.Addr() +} + +// Close shuts down the SNI listener and notifies all child listeners that the +// socket is closed. +// +// It is imperative that every child listener created with AddVirtualHost has +// an active Accept() call underway before Close() is called. Close() will +// block on writes to each child listener's connection channel. +func (l *SNIListener) Close() error { + return l.listener.Close() +} + +func (v *sniVirtualHost) Accept() (net.Conn, error) { + c := <-v.connChan + if c.conn != nil { + v.logger.V(3).Noticef("recv'd dispatched conn: %s <-> %s", c.conn.RemoteAddr(), c.conn.LocalAddr()) + } else if c.err != nil { + v.logger.V(3).Noticef("recv'd dispatched err: (%T) %v", c.err, c.err) + return nil, c.err + } + + if v.tlsConfig == nil { + return c.conn, nil + } + + tlsConn := tls.Server(c.conn, v.tlsConfig) + + if err := tlsConn.Handshake(); err != nil { + c.conn.Close() + return nil, err + } + + return tlsConn, nil +} + +func (v *sniVirtualHost) Addr() net.Addr { + return v.sniListener.Addr() +} + +func (v *sniVirtualHost) Close() error { + return v.sniListener.Close() +} + +// ref: https://www.agwa.name/blog/post/writing_an_sni_proxy_in_go +func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) { + var hello *tls.ClientHelloInfo + + err := tls.Server(&compositeConn{reader: reader}, &tls.Config{ + GetConfigForClient: func(incoming *tls.ClientHelloInfo) (*tls.Config, error) { + hello = new(tls.ClientHelloInfo) + *hello = *incoming + + return nil, nil + }, + }).Handshake() + + // we expect the handshake to fail, so only return err if hello doesn't get populated + if hello == nil { + return nil, err + } + + return hello, nil +} + +func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) { + buf := new(bytes.Buffer) + hello, err := readClientHello(io.TeeReader(reader, buf)) + if err != nil { + return nil, nil, err + } + return hello, io.MultiReader(buf, reader), nil +} + +type compositeConn struct { + reader io.Reader + writer io.Writer +} + +func (c *compositeConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +func (c *compositeConn) Write(p []byte) (int, error) { + if c.writer != nil { + return c.writer.Write(p) + } + return 0, io.ErrClosedPipe +} + +func (c *compositeConn) Close() error { + var err error + if rc, ok := c.reader.(io.ReadCloser); ok { + err = rc.Close() + } + if c.writer != nil { + if wc, ok := c.writer.(io.WriteCloser); ok { + err = wc.Close() + } + } + return err +} + +func (c *compositeConn) LocalAddr() net.Addr { + if conn, ok := c.reader.(net.Conn); ok { + return conn.LocalAddr() + } + if c.writer != nil { + if conn, ok := c.writer.(net.Conn); ok { + return conn.LocalAddr() + } + } + return nil +} + +func (c *compositeConn) RemoteAddr() net.Addr { + if conn, ok := c.reader.(net.Conn); ok { + return conn.RemoteAddr() + } + if c.writer != nil { + if conn, ok := c.writer.(net.Conn); ok { + return conn.RemoteAddr() + } + } + return nil +} + +func (c *compositeConn) SetDeadline(t time.Time) error { + if conn, ok := c.reader.(net.Conn); ok { + return conn.SetDeadline(t) + } + if c.writer != nil { + if conn, ok := c.writer.(net.Conn); ok { + return conn.SetDeadline(t) + } + } + return nil +} + +func (c *compositeConn) SetReadDeadline(t time.Time) error { + if conn, ok := c.reader.(net.Conn); ok { + return conn.SetReadDeadline(t) + } + if c.writer != nil { + if conn, ok := c.writer.(net.Conn); ok { + return conn.SetReadDeadline(t) + } + } + return nil +} + +func (c *compositeConn) SetWriteDeadline(t time.Time) error { + if conn, ok := c.reader.(net.Conn); ok { + return conn.SetWriteDeadline(t) + } + if c.writer != nil { + if conn, ok := c.writer.(net.Conn); ok { + return conn.SetWriteDeadline(t) + } + } + return nil +} diff --git a/mtls/fsnotify/fsnotify.go b/mtls/fsnotify/fsnotify.go index 3cc71b0..c42ed84 100644 --- a/mtls/fsnotify/fsnotify.go +++ b/mtls/fsnotify/fsnotify.go @@ -278,7 +278,7 @@ func handleEvent(event fsnotify.Event) { return } } - logger.V(1).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op) + logger.V(2).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op) } // handleEventForPath is the second stage of event handling which calls the actual handlers. diff --git a/utils/log/http.go b/utils/log/http.go index fe8a2ff..7ae8e7f 100644 --- a/utils/log/http.go +++ b/utils/log/http.go @@ -2,6 +2,7 @@ package log import ( "bufio" + "context" "encoding/json" "net" "net/http" @@ -29,9 +30,31 @@ type statusRecorder struct { BytesWritten uint } +type ExtraFunc func(map[string]any, http.ResponseWriter, *http.Request) +type mutableContext map[any]any +type ctxVar int + +const ( + kLoggerCtx ctxVar = iota +) + +// LoggingMiddleware is http middleware that writes access logs in JSON format to +// the specified logger. +// +// To use LoggingMiddleware, call NewLoggingMiddleware with the http.Handler it +// wraps, and set the Server or ServeMux handler function to the return value of +// LoggingMiddleware.HandlerFunc(). +// +// Each request's context has a mutable state map added to it, allowing arbitrary +// data to be passed between the inner handler and the logger. type LoggingMiddleware struct { - Logger coreLogger - ExtraFunc func(map[string]any) + // Logger is the logger to which access logs are written. + Logger coreLogger + // ExtraFunc is called when the log entry is being generated, allowing arbitrary + // values to be added to the log entry. + // + // Implementations MUST NOT mutate the provided http.ResponseWriter. + ExtraFunc ExtraFunc extraRequestHeaders []string extraResponseHeaders []string @@ -42,7 +65,7 @@ func NewLoggingMiddleware(h http.Handler) *LoggingMiddleware { return NewLoggingMiddlewareWithLogger(h, Default()) } -func NewLoggingMiddlewareWithLogger(h http.Handler, logger Logger) *LoggingMiddleware { +func NewLoggingMiddlewareWithLogger(h http.Handler, logger coreLogger) *LoggingMiddleware { lm := &LoggingMiddleware{ Logger: logger, h: h, @@ -78,6 +101,7 @@ func (lm *LoggingMiddleware) handle(w http.ResponseWriter, r *http.Request) { ws.Hijacker = h } + r = r.WithContext(context.WithValue(r.Context(), kLoggerCtx, make(mutableContext))) startTime := time.Now().UnixMilli() lm.h.ServeHTTP(ws, r) respTime := time.Now().UnixMilli() - startTime @@ -101,7 +125,7 @@ func (lm *LoggingMiddleware) handle(w http.ResponseWriter, r *http.Request) { entry[hKey] = w.Header().Get(h) } if lm.ExtraFunc != nil { - lm.ExtraFunc(entry) + lm.ExtraFunc(entry, w, r) } entryJson, err := json.Marshal(entry) @@ -110,6 +134,21 @@ func (lm *LoggingMiddleware) handle(w http.ResponseWriter, r *http.Request) { } } +func StateFromContext(ctx context.Context) mutableContext { + return ctx.Value(kLoggerCtx).(mutableContext) +} + +func (m mutableContext) Set(k, v any) { + m[k] = v +} + +func (m mutableContext) Get(k any) any { + if v, ok := m[k]; ok { + return v + } + return nil +} + func (r *statusRecorder) WriteHeader(status int) { r.ResponseWriter.WriteHeader(status) r.Status = status diff --git a/utils/log/log.go b/utils/log/log.go index 21a4c03..3ca0098 100644 --- a/utils/log/log.go +++ b/utils/log/log.go @@ -164,11 +164,14 @@ func (l *internalLogger) Print(v ...any) { } func (l *internalLogger) Printf(fmtstr string, v ...any) { + args := v if l.prefix != "" { - fmtstr = "[" + l.prefix + "] " + fmtstr + fmtstr = "[%s] " + fmtstr + args = []any{l.prefix} + args = append(args, v...) } - l.Logger.Printf(l.level.prefix(l.Writer())+fmtstr, v...) + l.Logger.Printf(l.level.prefix(l.Writer())+fmtstr, args...) } func (l *internalLogger) Println(v ...any) { -- 2.50.1