From 358df2ef4f5418a99973f0eb761bbd348a41830e Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Mon, 30 Dec 2024 14:14:50 -0500 Subject: [PATCH] move HTTP logging middleware to log package HTTP access logs are needed in more places, so it makes sense to make this code reusable. --- sase/ws_proxy.go | 59 +----------------------- utils/log/http.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 57 deletions(-) create mode 100644 utils/log/http.go diff --git a/sase/ws_proxy.go b/sase/ws_proxy.go index f0c4566..461ce51 100644 --- a/sase/ws_proxy.go +++ b/sase/ws_proxy.go @@ -1,9 +1,7 @@ package sase import ( - "bufio" "context" - "encoding/json" "fmt" "net" "net/http" @@ -28,22 +26,6 @@ type WebSocketRequest struct { addr []net.Addr } -type logEntry struct { - RemoteAddress, - Host, - Method, - Path string - - StatusCode int -} - -type statusRecorder struct { - http.ResponseWriter - http.Hijacker - - Status int -} - var ( pathRegexp *regexp.Regexp ) @@ -59,9 +41,10 @@ func init() { func NewWebSocketProxy(listen string) (*WebSocketProxy, error) { handler := http.NewServeMux() + lm := log.NewLoggingMiddleware(handler) server := &http.Server{ Addr: listen, - Handler: newLoggingMiddleware(handler), + Handler: lm.HandlerFunc(), } wsp := &WebSocketProxy{ @@ -198,41 +181,3 @@ func (wsp *WebSocketProxy) replyBadRequest(w http.ResponseWriter, msg string) { w.WriteHeader(400) fmt.Fprintf(w, "400 Bad Request: %s\n", msg) } - -func newLoggingMiddleware(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ws := &statusRecorder{ - ResponseWriter: w, - } - if h, ok := w.(http.Hijacker); ok { - ws.Hijacker = h - } - - h.ServeHTTP(ws, r) - - logger := log.Default() - - entry := logEntry{ - RemoteAddress: r.RemoteAddr, - Host: r.Host, - Method: r.Method, - Path: r.URL.Path, - StatusCode: ws.Status, - } - - entryJson, err := json.Marshal(entry) - if err == nil { - logger.Print(string(entryJson)) - } - }) -} - -func (r *statusRecorder) WriteHeader(status int) { - r.ResponseWriter.WriteHeader(status) - r.Status = status -} - -func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { - r.Status = http.StatusSwitchingProtocols - return r.Hijacker.Hijack() -} diff --git a/utils/log/http.go b/utils/log/http.go new file mode 100644 index 0000000..c212681 --- /dev/null +++ b/utils/log/http.go @@ -0,0 +1,111 @@ +package log + +import ( + "bufio" + "encoding/json" + "net" + "net/http" + "net/textproto" + "strings" +) + +const ( + kRemoteAddress = "remote_address" + kHost = "authority" + kMethod = "method" + kPath = "path" + kUserAgent = "user_agent" + kStatusCode = "status" +) + +type statusRecorder struct { + http.ResponseWriter + http.Hijacker + + Status int +} + +type LoggingMiddleware struct { + Logger Logger + ExtraFunc func(map[string]any) + + extraRequestHeaders []string + extraResponseHeaders []string + h http.Handler +} + +func NewLoggingMiddleware(h http.Handler) *LoggingMiddleware { + return NewLoggingMiddlewareWithLogger(h, Default()) +} + +func NewLoggingMiddlewareWithLogger(h http.Handler, logger Logger) *LoggingMiddleware { + lm := &LoggingMiddleware{ + Logger: logger, + + extraRequestHeaders: []string{"user-agent"}, + extraResponseHeaders: []string{"content-type"}, + } + return lm +} + +func (lm *LoggingMiddleware) AddRequestHeader(header ...string) { + for _, h := range header { + lm.extraRequestHeaders = append(lm.extraRequestHeaders, textproto.CanonicalMIMEHeaderKey(h)) + } +} + +func (lm *LoggingMiddleware) AddResponseHeader(header ...string) { + for _, h := range header { + lm.extraResponseHeaders = append(lm.extraResponseHeaders, textproto.CanonicalMIMEHeaderKey(h)) + } +} + +func (lm *LoggingMiddleware) HandlerFunc() http.HandlerFunc { + return http.HandlerFunc(lm.handle) +} + +func (lm *LoggingMiddleware) handle(w http.ResponseWriter, r *http.Request) { + ws := &statusRecorder{ + ResponseWriter: w, + } + if h, ok := w.(http.Hijacker); ok { + ws.Hijacker = h + } + + lm.h.ServeHTTP(ws, r) + + entry := map[string]any{ + kRemoteAddress: r.RemoteAddr, + kHost: r.Host, + kMethod: r.Method, + kPath: r.URL.Path, + kStatusCode: ws.Status, + } + + for _, h := range lm.extraRequestHeaders { + hKey := strings.ReplaceAll("-", "_", h) + entry[hKey] = r.Header.Get(h) + } + for _, h := range lm.extraResponseHeaders { + hKey := strings.ReplaceAll("-", "_", h) + entry[hKey] = r.Header.Get(h) + } + if lm.ExtraFunc != nil { + lm.ExtraFunc(entry) + } + + entryJson, err := json.Marshal(entry) + if err == nil { + lm.Logger.Print(string(entryJson)) + } +} + +func (r *statusRecorder) WriteHeader(status int) { + r.ResponseWriter.WriteHeader(status) + r.Status = status +} + +func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + r.Status = http.StatusSwitchingProtocols + return r.Hijacker.Hijack() +} -- 2.50.1