From 3bf0b0f20310252fddb8ea6404f20b3ce7bbbbfb Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Sun, 23 Mar 2025 00:13:16 -0400 Subject: [PATCH] [http/samlproxy] support route actions - just redirects for now --- http/samlproxy.go | 123 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 102 insertions(+), 21 deletions(-) diff --git a/http/samlproxy.go b/http/samlproxy.go index da87141..489f5d2 100644 --- a/http/samlproxy.go +++ b/http/samlproxy.go @@ -7,7 +7,6 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" - "encoding/json" "errors" "fmt" "io" @@ -35,9 +34,19 @@ import ( type authEnforcement uint +type RouteAction interface { + Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) +} + +type RedirectAction struct { + StatusCode int + Destination *url.URL +} + type Route struct { - Auth authEnforcement - Path stringmatch.StringMatcher + Auth authEnforcement + Path stringmatch.StringMatcher + Action RouteAction } type SAMLBackend struct { @@ -94,29 +103,42 @@ const ( var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`) var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"}) -// UnmarshalJSON implements json.Unmarshaler -func (ae *authEnforcement) UnmarshalJSON(token []byte) error { - var s string - if err := json.Unmarshal(token, &s); err != nil { - return errors.New("cannot unmarshal authEnforcement to string") +// Handle implements RouteAction +func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + newUrl := *r.URL + + if a.Destination.Host != "" { + newUrl.Host = a.Destination.Host } - switch s { - case "required": - *ae = AuthRequired - case "optional": - *ae = AuthOptional - default: - return fmt.Errorf("invalid auth enforcement string value: %s", s) + newUrl.Path = a.Destination.Path + if a.Destination.RawQuery != "" { + newUrl.RawQuery = a.Destination.RawQuery + } + if a.Destination.Scheme != "" { + newUrl.Scheme = a.Destination.Scheme + } + if a.Destination.Fragment != "" { + newUrl.Fragment = a.Destination.Fragment } - return nil + status := a.StatusCode + if status == 0 { + status = http.StatusFound + } + + w.Header().Set("location", newUrl.String()) + w.WriteHeader(status) } // UnmarshalYAML implements yaml.Unmarshaler func (r *Route) UnmarshalYAML(node *yaml.Node) error { var rawNode struct { - Auth string `yaml:"auth"` - Path *stringmatch.MatchRule `yaml:"path"` + Auth string `yaml:"auth"` + Path *stringmatch.MatchRule `yaml:"path"` + Redirect *struct { + Destination string `yaml:"dest"` + Status int `yaml:"status"` + } `yaml:"redirect"` } if err := node.Decode(&rawNode); err != nil { @@ -142,6 +164,17 @@ func (r *Route) UnmarshalYAML(node *yaml.Node) error { return errors.New("error unmarshaling route: exactly one of (`path`) must be specified") } + if rawNode.Redirect != nil { + u, err := url.Parse(rawNode.Redirect.Destination) + if err != nil { + return err + } + r.Action = &RedirectAction{ + Destination: u, + StatusCode: rawNode.Redirect.Status, + } + } + return nil } @@ -486,6 +519,8 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) { } defaultRoute := true + next := sp.fulfill(vhost, session) + sp.logger.V(3).Debugf("checking for routes matching %s", r.URL) for _, route := range vhost.Routes { match := false @@ -506,6 +541,11 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) { provider.HandleStartAuthFlow(w, r) return } + + if route.Action != nil { + sp.logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action) + route.Action.Handle(w, r, next) + } } } @@ -518,6 +558,14 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) { } } + next(w, r) + } + + return handle, nil +} + +func (sp *SAMLProxy) fulfill(vhost *SAMLVirtualHost, session samlsp.Session) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { if session != nil { sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session) } else { @@ -587,12 +635,45 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) { for name, value := range response.Header { w.Header().Set(name, strings.Join(value, ", ")) } - w.WriteHeader(response.StatusCode) + if response.StatusCode == http.StatusSwitchingProtocols { + hijacker, ok := w.(http.Hijacker) + if !ok { + sp.writeError(w, http.StatusMethodNotAllowed, errors.New("websocket passthrough not supported")) + return + } + + upstreamWriter, ok := response.Body.(io.Writer) + if !ok { + sp.writeError(w, http.StatusMethodNotAllowed, errors.New("body doesn't support io.Writer")) + return + } + + w.WriteHeader(response.StatusCode) + + conn, rw, err := hijacker.Hijack() + if err != nil { + sp.writeError(w, http.StatusInternalServerError, err) + return + } + + wg := sync.WaitGroup{} + wg.Add(2) + pipe := func(w io.Writer, r io.Reader) { + defer wg.Done() + io.Copy(w, r) + } + go pipe(rw, response.Body) + go pipe(upstreamWriter, rw) + + wg.Wait() + conn.Close() + return + } + + w.WriteHeader(response.StatusCode) io.Copy(w, response.Body) } - - return handle, nil } func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) { -- 2.50.1