]> go.fuhry.dev Git - runtime.git/commitdiff
move HTTP logging middleware to log package
authorDan Fuhry <dan@fuhry.com>
Mon, 30 Dec 2024 19:14:50 +0000 (14:14 -0500)
committerDan Fuhry <dan@fuhry.com>
Mon, 30 Dec 2024 19:14:50 +0000 (14:14 -0500)
HTTP access logs are needed in more places, so it makes sense to make this code reusable.

sase/ws_proxy.go
utils/log/http.go [new file with mode: 0644]

index f0c456663f36dcfdcb098fc9708cdf0a6c39cf8b..461ce51b9ebf1c48c806fea0c9296906f753695e 100644 (file)
@@ -1,9 +1,7 @@
 package sase
 
 import (
-       "bufio"
        "context"
-       "encoding/json"
        "fmt"
        "net"
        "net/http"
@@ -28,22 +26,6 @@ type WebSocketRequest struct {
        addr     []net.Addr
 }
 
-type logEntry struct {
-       RemoteAddress,
-       Host,
-       Method,
-       Path string
-
-       StatusCode int
-}
-
-type statusRecorder struct {
-       http.ResponseWriter
-       http.Hijacker
-
-       Status int
-}
-
 var (
        pathRegexp *regexp.Regexp
 )
@@ -59,9 +41,10 @@ func init() {
 
 func NewWebSocketProxy(listen string) (*WebSocketProxy, error) {
        handler := http.NewServeMux()
+       lm := log.NewLoggingMiddleware(handler)
        server := &http.Server{
                Addr:    listen,
-               Handler: newLoggingMiddleware(handler),
+               Handler: lm.HandlerFunc(),
        }
 
        wsp := &WebSocketProxy{
@@ -198,41 +181,3 @@ func (wsp *WebSocketProxy) replyBadRequest(w http.ResponseWriter, msg string) {
        w.WriteHeader(400)
        fmt.Fprintf(w, "400 Bad Request: %s\n", msg)
 }
-
-func newLoggingMiddleware(h http.Handler) http.Handler {
-       return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-               ws := &statusRecorder{
-                       ResponseWriter: w,
-               }
-               if h, ok := w.(http.Hijacker); ok {
-                       ws.Hijacker = h
-               }
-
-               h.ServeHTTP(ws, r)
-
-               logger := log.Default()
-
-               entry := logEntry{
-                       RemoteAddress: r.RemoteAddr,
-                       Host:          r.Host,
-                       Method:        r.Method,
-                       Path:          r.URL.Path,
-                       StatusCode:    ws.Status,
-               }
-
-               entryJson, err := json.Marshal(entry)
-               if err == nil {
-                       logger.Print(string(entryJson))
-               }
-       })
-}
-
-func (r *statusRecorder) WriteHeader(status int) {
-       r.ResponseWriter.WriteHeader(status)
-       r.Status = status
-}
-
-func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-       r.Status = http.StatusSwitchingProtocols
-       return r.Hijacker.Hijack()
-}
diff --git a/utils/log/http.go b/utils/log/http.go
new file mode 100644 (file)
index 0000000..c212681
--- /dev/null
@@ -0,0 +1,111 @@
+package log
+
+import (
+       "bufio"
+       "encoding/json"
+       "net"
+       "net/http"
+       "net/textproto"
+       "strings"
+)
+
+const (
+       kRemoteAddress = "remote_address"
+       kHost          = "authority"
+       kMethod        = "method"
+       kPath          = "path"
+       kUserAgent     = "user_agent"
+       kStatusCode    = "status"
+)
+
+type statusRecorder struct {
+       http.ResponseWriter
+       http.Hijacker
+
+       Status int
+}
+
+type LoggingMiddleware struct {
+       Logger    Logger
+       ExtraFunc func(map[string]any)
+
+       extraRequestHeaders  []string
+       extraResponseHeaders []string
+       h                    http.Handler
+}
+
+func NewLoggingMiddleware(h http.Handler) *LoggingMiddleware {
+       return NewLoggingMiddlewareWithLogger(h, Default())
+}
+
+func NewLoggingMiddlewareWithLogger(h http.Handler, logger Logger) *LoggingMiddleware {
+       lm := &LoggingMiddleware{
+               Logger: logger,
+
+               extraRequestHeaders:  []string{"user-agent"},
+               extraResponseHeaders: []string{"content-type"},
+       }
+       return lm
+}
+
+func (lm *LoggingMiddleware) AddRequestHeader(header ...string) {
+       for _, h := range header {
+               lm.extraRequestHeaders = append(lm.extraRequestHeaders, textproto.CanonicalMIMEHeaderKey(h))
+       }
+}
+
+func (lm *LoggingMiddleware) AddResponseHeader(header ...string) {
+       for _, h := range header {
+               lm.extraResponseHeaders = append(lm.extraResponseHeaders, textproto.CanonicalMIMEHeaderKey(h))
+       }
+}
+
+func (lm *LoggingMiddleware) HandlerFunc() http.HandlerFunc {
+       return http.HandlerFunc(lm.handle)
+}
+
+func (lm *LoggingMiddleware) handle(w http.ResponseWriter, r *http.Request) {
+       ws := &statusRecorder{
+               ResponseWriter: w,
+       }
+       if h, ok := w.(http.Hijacker); ok {
+               ws.Hijacker = h
+       }
+
+       lm.h.ServeHTTP(ws, r)
+
+       entry := map[string]any{
+               kRemoteAddress: r.RemoteAddr,
+               kHost:          r.Host,
+               kMethod:        r.Method,
+               kPath:          r.URL.Path,
+               kStatusCode:    ws.Status,
+       }
+
+       for _, h := range lm.extraRequestHeaders {
+               hKey := strings.ReplaceAll("-", "_", h)
+               entry[hKey] = r.Header.Get(h)
+       }
+       for _, h := range lm.extraResponseHeaders {
+               hKey := strings.ReplaceAll("-", "_", h)
+               entry[hKey] = r.Header.Get(h)
+       }
+       if lm.ExtraFunc != nil {
+               lm.ExtraFunc(entry)
+       }
+
+       entryJson, err := json.Marshal(entry)
+       if err == nil {
+               lm.Logger.Print(string(entryJson))
+       }
+}
+
+func (r *statusRecorder) WriteHeader(status int) {
+       r.ResponseWriter.WriteHeader(status)
+       r.Status = status
+}
+
+func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+       r.Status = http.StatusSwitchingProtocols
+       return r.Hijacker.Hijack()
+}