"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
- "encoding/json"
"errors"
"fmt"
"io"
type authEnforcement uint
+type RouteAction interface {
+ Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)
+}
+
+type RedirectAction struct {
+ StatusCode int
+ Destination *url.URL
+}
+
type Route struct {
- Auth authEnforcement
- Path stringmatch.StringMatcher
+ Auth authEnforcement
+ Path stringmatch.StringMatcher
+ Action RouteAction
}
type SAMLBackend struct {
var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`)
var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"})
-// UnmarshalJSON implements json.Unmarshaler
-func (ae *authEnforcement) UnmarshalJSON(token []byte) error {
- var s string
- if err := json.Unmarshal(token, &s); err != nil {
- return errors.New("cannot unmarshal authEnforcement to string")
+// Handle implements RouteAction
+func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
+ newUrl := *r.URL
+
+ if a.Destination.Host != "" {
+ newUrl.Host = a.Destination.Host
}
- switch s {
- case "required":
- *ae = AuthRequired
- case "optional":
- *ae = AuthOptional
- default:
- return fmt.Errorf("invalid auth enforcement string value: %s", s)
+ newUrl.Path = a.Destination.Path
+ if a.Destination.RawQuery != "" {
+ newUrl.RawQuery = a.Destination.RawQuery
+ }
+ if a.Destination.Scheme != "" {
+ newUrl.Scheme = a.Destination.Scheme
+ }
+ if a.Destination.Fragment != "" {
+ newUrl.Fragment = a.Destination.Fragment
}
- return nil
+ status := a.StatusCode
+ if status == 0 {
+ status = http.StatusFound
+ }
+
+ w.Header().Set("location", newUrl.String())
+ w.WriteHeader(status)
}
// UnmarshalYAML implements yaml.Unmarshaler
func (r *Route) UnmarshalYAML(node *yaml.Node) error {
var rawNode struct {
- Auth string `yaml:"auth"`
- Path *stringmatch.MatchRule `yaml:"path"`
+ Auth string `yaml:"auth"`
+ Path *stringmatch.MatchRule `yaml:"path"`
+ Redirect *struct {
+ Destination string `yaml:"dest"`
+ Status int `yaml:"status"`
+ } `yaml:"redirect"`
}
if err := node.Decode(&rawNode); err != nil {
return errors.New("error unmarshaling route: exactly one of (`path`) must be specified")
}
+ if rawNode.Redirect != nil {
+ u, err := url.Parse(rawNode.Redirect.Destination)
+ if err != nil {
+ return err
+ }
+ r.Action = &RedirectAction{
+ Destination: u,
+ StatusCode: rawNode.Redirect.Status,
+ }
+ }
+
return nil
}
}
defaultRoute := true
+ next := sp.fulfill(vhost, session)
+
sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
for _, route := range vhost.Routes {
match := false
provider.HandleStartAuthFlow(w, r)
return
}
+
+ if route.Action != nil {
+ sp.logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action)
+ route.Action.Handle(w, r, next)
+ }
}
}
}
}
+ next(w, r)
+ }
+
+ return handle, nil
+}
+
+func (sp *SAMLProxy) fulfill(vhost *SAMLVirtualHost, session samlsp.Session) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
if session != nil {
sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session)
} else {
for name, value := range response.Header {
w.Header().Set(name, strings.Join(value, ", "))
}
- w.WriteHeader(response.StatusCode)
+ if response.StatusCode == http.StatusSwitchingProtocols {
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ sp.writeError(w, http.StatusMethodNotAllowed, errors.New("websocket passthrough not supported"))
+ return
+ }
+
+ upstreamWriter, ok := response.Body.(io.Writer)
+ if !ok {
+ sp.writeError(w, http.StatusMethodNotAllowed, errors.New("body doesn't support io.Writer"))
+ return
+ }
+
+ w.WriteHeader(response.StatusCode)
+
+ conn, rw, err := hijacker.Hijack()
+ if err != nil {
+ sp.writeError(w, http.StatusInternalServerError, err)
+ return
+ }
+
+ wg := sync.WaitGroup{}
+ wg.Add(2)
+ pipe := func(w io.Writer, r io.Reader) {
+ defer wg.Done()
+ io.Copy(w, r)
+ }
+ go pipe(rw, response.Body)
+ go pipe(upstreamWriter, rw)
+
+ wg.Wait()
+ conn.Close()
+ return
+ }
+
+ w.WriteHeader(response.StatusCode)
io.Copy(w, response.Body)
}
-
- return handle, nil
}
func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) {