]> go.fuhry.dev Git - runtime.git/commitdiff
[http/samlproxy] support route actions - just redirects for now
authorDan Fuhry <dan@fuhry.com>
Sun, 23 Mar 2025 04:13:16 +0000 (00:13 -0400)
committerDan Fuhry <dan@fuhry.com>
Sun, 23 Mar 2025 04:13:16 +0000 (00:13 -0400)
http/samlproxy.go

index da87141f2fc69aa5a5d2cb260fc5ba589f21345f..489f5d23cd1153e0280fa1d5c3c96428ae26f21d 100644 (file)
@@ -7,7 +7,6 @@ import (
        "crypto/tls"
        "crypto/x509"
        "crypto/x509/pkix"
-       "encoding/json"
        "errors"
        "fmt"
        "io"
@@ -35,9 +34,19 @@ import (
 
 type authEnforcement uint
 
+type RouteAction interface {
+       Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)
+}
+
+type RedirectAction struct {
+       StatusCode  int
+       Destination *url.URL
+}
+
 type Route struct {
-       Auth authEnforcement
-       Path stringmatch.StringMatcher
+       Auth   authEnforcement
+       Path   stringmatch.StringMatcher
+       Action RouteAction
 }
 
 type SAMLBackend struct {
@@ -94,29 +103,42 @@ const (
 var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`)
 var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"})
 
-// UnmarshalJSON implements json.Unmarshaler
-func (ae *authEnforcement) UnmarshalJSON(token []byte) error {
-       var s string
-       if err := json.Unmarshal(token, &s); err != nil {
-               return errors.New("cannot unmarshal authEnforcement to string")
+// Handle implements RouteAction
+func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
+       newUrl := *r.URL
+
+       if a.Destination.Host != "" {
+               newUrl.Host = a.Destination.Host
        }
-       switch s {
-       case "required":
-               *ae = AuthRequired
-       case "optional":
-               *ae = AuthOptional
-       default:
-               return fmt.Errorf("invalid auth enforcement string value: %s", s)
+       newUrl.Path = a.Destination.Path
+       if a.Destination.RawQuery != "" {
+               newUrl.RawQuery = a.Destination.RawQuery
+       }
+       if a.Destination.Scheme != "" {
+               newUrl.Scheme = a.Destination.Scheme
+       }
+       if a.Destination.Fragment != "" {
+               newUrl.Fragment = a.Destination.Fragment
        }
 
-       return nil
+       status := a.StatusCode
+       if status == 0 {
+               status = http.StatusFound
+       }
+
+       w.Header().Set("location", newUrl.String())
+       w.WriteHeader(status)
 }
 
 // UnmarshalYAML implements yaml.Unmarshaler
 func (r *Route) UnmarshalYAML(node *yaml.Node) error {
        var rawNode struct {
-               Auth string                 `yaml:"auth"`
-               Path *stringmatch.MatchRule `yaml:"path"`
+               Auth     string                 `yaml:"auth"`
+               Path     *stringmatch.MatchRule `yaml:"path"`
+               Redirect *struct {
+                       Destination string `yaml:"dest"`
+                       Status      int    `yaml:"status"`
+               } `yaml:"redirect"`
        }
 
        if err := node.Decode(&rawNode); err != nil {
@@ -142,6 +164,17 @@ func (r *Route) UnmarshalYAML(node *yaml.Node) error {
                return errors.New("error unmarshaling route: exactly one of (`path`) must be specified")
        }
 
+       if rawNode.Redirect != nil {
+               u, err := url.Parse(rawNode.Redirect.Destination)
+               if err != nil {
+                       return err
+               }
+               r.Action = &RedirectAction{
+                       Destination: u,
+                       StatusCode:  rawNode.Redirect.Status,
+               }
+       }
+
        return nil
 }
 
@@ -486,6 +519,8 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
                }
 
                defaultRoute := true
+               next := sp.fulfill(vhost, session)
+
                sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
                for _, route := range vhost.Routes {
                        match := false
@@ -506,6 +541,11 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
                                        provider.HandleStartAuthFlow(w, r)
                                        return
                                }
+
+                               if route.Action != nil {
+                                       sp.logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action)
+                                       route.Action.Handle(w, r, next)
+                               }
                        }
                }
 
@@ -518,6 +558,14 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
                        }
                }
 
+               next(w, r)
+       }
+
+       return handle, nil
+}
+
+func (sp *SAMLProxy) fulfill(vhost *SAMLVirtualHost, session samlsp.Session) http.HandlerFunc {
+       return func(w http.ResponseWriter, r *http.Request) {
                if session != nil {
                        sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session)
                } else {
@@ -587,12 +635,45 @@ func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
                for name, value := range response.Header {
                        w.Header().Set(name, strings.Join(value, ", "))
                }
-               w.WriteHeader(response.StatusCode)
 
+               if response.StatusCode == http.StatusSwitchingProtocols {
+                       hijacker, ok := w.(http.Hijacker)
+                       if !ok {
+                               sp.writeError(w, http.StatusMethodNotAllowed, errors.New("websocket passthrough not supported"))
+                               return
+                       }
+
+                       upstreamWriter, ok := response.Body.(io.Writer)
+                       if !ok {
+                               sp.writeError(w, http.StatusMethodNotAllowed, errors.New("body doesn't support io.Writer"))
+                               return
+                       }
+
+                       w.WriteHeader(response.StatusCode)
+
+                       conn, rw, err := hijacker.Hijack()
+                       if err != nil {
+                               sp.writeError(w, http.StatusInternalServerError, err)
+                               return
+                       }
+
+                       wg := sync.WaitGroup{}
+                       wg.Add(2)
+                       pipe := func(w io.Writer, r io.Reader) {
+                               defer wg.Done()
+                               io.Copy(w, r)
+                       }
+                       go pipe(rw, response.Body)
+                       go pipe(upstreamWriter, rw)
+
+                       wg.Wait()
+                       conn.Close()
+                       return
+               }
+
+               w.WriteHeader(response.StatusCode)
                io.Copy(w, response.Body)
        }
-
-       return handle, nil
 }
 
 func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) {