package sase
import (
- "bufio"
"context"
- "encoding/json"
"fmt"
"net"
"net/http"
addr []net.Addr
}
-type logEntry struct {
- RemoteAddress,
- Host,
- Method,
- Path string
-
- StatusCode int
-}
-
-type statusRecorder struct {
- http.ResponseWriter
- http.Hijacker
-
- Status int
-}
-
var (
pathRegexp *regexp.Regexp
)
func NewWebSocketProxy(listen string) (*WebSocketProxy, error) {
handler := http.NewServeMux()
+ lm := log.NewLoggingMiddleware(handler)
server := &http.Server{
Addr: listen,
- Handler: newLoggingMiddleware(handler),
+ Handler: lm.HandlerFunc(),
}
wsp := &WebSocketProxy{
w.WriteHeader(400)
fmt.Fprintf(w, "400 Bad Request: %s\n", msg)
}
-
-func newLoggingMiddleware(h http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ws := &statusRecorder{
- ResponseWriter: w,
- }
- if h, ok := w.(http.Hijacker); ok {
- ws.Hijacker = h
- }
-
- h.ServeHTTP(ws, r)
-
- logger := log.Default()
-
- entry := logEntry{
- RemoteAddress: r.RemoteAddr,
- Host: r.Host,
- Method: r.Method,
- Path: r.URL.Path,
- StatusCode: ws.Status,
- }
-
- entryJson, err := json.Marshal(entry)
- if err == nil {
- logger.Print(string(entryJson))
- }
- })
-}
-
-func (r *statusRecorder) WriteHeader(status int) {
- r.ResponseWriter.WriteHeader(status)
- r.Status = status
-}
-
-func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- r.Status = http.StatusSwitchingProtocols
- return r.Hijacker.Hijack()
-}
--- /dev/null
+package log
+
+import (
+ "bufio"
+ "encoding/json"
+ "net"
+ "net/http"
+ "net/textproto"
+ "strings"
+)
+
+const (
+ kRemoteAddress = "remote_address"
+ kHost = "authority"
+ kMethod = "method"
+ kPath = "path"
+ kUserAgent = "user_agent"
+ kStatusCode = "status"
+)
+
+type statusRecorder struct {
+ http.ResponseWriter
+ http.Hijacker
+
+ Status int
+}
+
+type LoggingMiddleware struct {
+ Logger Logger
+ ExtraFunc func(map[string]any)
+
+ extraRequestHeaders []string
+ extraResponseHeaders []string
+ h http.Handler
+}
+
+func NewLoggingMiddleware(h http.Handler) *LoggingMiddleware {
+ return NewLoggingMiddlewareWithLogger(h, Default())
+}
+
+func NewLoggingMiddlewareWithLogger(h http.Handler, logger Logger) *LoggingMiddleware {
+ lm := &LoggingMiddleware{
+ Logger: logger,
+
+ extraRequestHeaders: []string{"user-agent"},
+ extraResponseHeaders: []string{"content-type"},
+ }
+ return lm
+}
+
+func (lm *LoggingMiddleware) AddRequestHeader(header ...string) {
+ for _, h := range header {
+ lm.extraRequestHeaders = append(lm.extraRequestHeaders, textproto.CanonicalMIMEHeaderKey(h))
+ }
+}
+
+func (lm *LoggingMiddleware) AddResponseHeader(header ...string) {
+ for _, h := range header {
+ lm.extraResponseHeaders = append(lm.extraResponseHeaders, textproto.CanonicalMIMEHeaderKey(h))
+ }
+}
+
+func (lm *LoggingMiddleware) HandlerFunc() http.HandlerFunc {
+ return http.HandlerFunc(lm.handle)
+}
+
+func (lm *LoggingMiddleware) handle(w http.ResponseWriter, r *http.Request) {
+ ws := &statusRecorder{
+ ResponseWriter: w,
+ }
+ if h, ok := w.(http.Hijacker); ok {
+ ws.Hijacker = h
+ }
+
+ lm.h.ServeHTTP(ws, r)
+
+ entry := map[string]any{
+ kRemoteAddress: r.RemoteAddr,
+ kHost: r.Host,
+ kMethod: r.Method,
+ kPath: r.URL.Path,
+ kStatusCode: ws.Status,
+ }
+
+ for _, h := range lm.extraRequestHeaders {
+ hKey := strings.ReplaceAll("-", "_", h)
+ entry[hKey] = r.Header.Get(h)
+ }
+ for _, h := range lm.extraResponseHeaders {
+ hKey := strings.ReplaceAll("-", "_", h)
+ entry[hKey] = r.Header.Get(h)
+ }
+ if lm.ExtraFunc != nil {
+ lm.ExtraFunc(entry)
+ }
+
+ entryJson, err := json.Marshal(entry)
+ if err == nil {
+ lm.Logger.Print(string(entryJson))
+ }
+}
+
+func (r *statusRecorder) WriteHeader(status int) {
+ r.ResponseWriter.WriteHeader(status)
+ r.Status = status
+}
+
+func (r *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ r.Status = http.StatusSwitchingProtocols
+ return r.Hijacker.Hijack()
+}