--- /dev/null
+package http
+
+import (
+ "context"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "regexp"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/crewjam/saml"
+ "github.com/crewjam/saml/samlsp"
+ "go.fuhry.dev/runtime/mtls"
+ "go.fuhry.dev/runtime/mtls/certutil"
+ "go.fuhry.dev/runtime/utils/hashset"
+ "go.fuhry.dev/runtime/utils/log"
+ "go.fuhry.dev/runtime/utils/stringmatch"
+ "gopkg.in/yaml.v3"
+)
+
+type authEnforcement uint
+
+const (
+ AuthRequired authEnforcement = iota
+ AuthOptional
+)
+
+var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`)
+var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"})
+
+// UnmarshalJSON implements json.Unmarshaler
+func (ae *authEnforcement) UnmarshalJSON(token []byte) error {
+ var s string
+ if err := json.Unmarshal(token, &s); err != nil {
+ return errors.New("cannot unmarshal authEnforcement to string")
+ }
+ switch s {
+ case "required":
+ *ae = AuthRequired
+ case "optional":
+ *ae = AuthOptional
+ default:
+ return fmt.Errorf("invalid auth enforcement string value: %s", s)
+ }
+
+ return nil
+}
+
+// UnmarshalYAML implements yaml.Unmarshaler
+func (ae *authEnforcement) UnmarshalYAML(node *yaml.Node) error {
+ if node.Kind != yaml.ScalarNode {
+ return errors.New("yaml field of type authEnforcement must be a string")
+ }
+
+ switch node.Value {
+ case "required":
+ *ae = AuthRequired
+ case "optional":
+ *ae = AuthOptional
+ default:
+ return fmt.Errorf("invalid auth enforcement string value: %s", node.Value)
+ }
+
+ return nil
+}
+
+type Route struct {
+ Auth authEnforcement
+ Path stringmatch.StringMatcher
+}
+
+type SAMLListener struct {
+ EntityID string `yaml:"entity_id"`
+ EntityCertificate string `yaml:"entity_certificate"`
+ EntityPrivateKey string `yaml:"entity_key"`
+ IDP string `yaml:"idp"`
+ Certificate string `yaml:"cert"`
+ Routes []Route `yaml:"routes"`
+}
+
+type SAMLBackend struct {
+ Host string `yaml:"host"`
+ Port int `yaml:"port"`
+ Identity string `yaml:"mtls_id"`
+}
+
+type SAMLProxy struct {
+ Listener SAMLListener `yaml:"listener"`
+ Backend SAMLBackend `yaml:"backend"`
+
+ logger *log.Logger
+ entityCert *x509.Certificate
+ entityKey *rsa.PrivateKey
+}
+
+func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
+ if sp.logger == nil {
+ sp.logger = log.WithPrefix("SAMLProxy")
+ }
+
+ idpMetadataUrl, err := url.Parse(sp.Listener.IDP)
+ if err != nil {
+ return nil, err
+ }
+ idpMetadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataUrl)
+ if err != nil {
+ return nil, err
+ }
+
+ handler, err := sp.newHandler(ctx, idpMetadata)
+ if err != nil {
+ return nil, err
+ }
+
+ if sp.Listener.EntityPrivateKey != "" {
+ pvk, err := certutil.LoadPrivateKeyFromPEM(sp.Listener.EntityPrivateKey)
+ if err != nil {
+ return nil, err
+ }
+ rsaKey, ok := pvk.(*rsa.PrivateKey)
+ if !ok {
+ return nil, fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
+ }
+ sp.entityKey = rsaKey
+ } else {
+ // generate new RSA private key
+ sp.entityKey, err = rsa.GenerateKey(saml.RandReader, 2048)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if sp.Listener.EntityCertificate != "" {
+ certs, err := certutil.LoadCertificatesFromPEM(sp.Listener.EntityCertificate)
+ if err != nil {
+ return nil, err
+ }
+ sp.entityCert = certs[0]
+ } else {
+ // generate new self-signed X509 certificate
+ serialBytes := make([]byte, 16)
+ saml.RandReader.Read(serialBytes)
+ serial := big.NewInt(0)
+ serial.SetBytes(serialBytes)
+
+ template := &x509.Certificate{
+ Subject: pkix.Name{
+ CommonName: sp.Listener.EntityID,
+ },
+ SerialNumber: serial,
+ BasicConstraintsValid: true,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
+ IsCA: false,
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(90 * 86400 * time.Second),
+ }
+ certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &sp.entityKey.PublicKey, sp.entityKey)
+ if err != nil {
+ return nil, err
+ }
+ cert, err := x509.ParseCertificate(certBytes)
+ if err != nil {
+ return nil, err
+ }
+ sp.entityCert = cert
+ }
+
+ server := &http.Server{
+ Handler: handler,
+ }
+
+ if sp.Listener.Certificate != "" {
+ cert := mtls.NewSSLCertificate(sp.Listener.Certificate)
+ tlsConfig, err := cert.TlsConfig(ctx)
+ if err != nil {
+ return nil, err
+ }
+ // verif := mtls.NewPeerNameVerifier()
+ // verif.AllowFrom(mtls.Service, "devicetrust")
+ // verif.ConfigureServer(tlsConfig)
+ // tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+ // tlsConfig.ClientCAs = x509.NewCertPool()
+ // ca, err := certutil.LoadCertificatesFromPEM("/etc/ssl/mtls/rootca.pem")
+ // if err == nil {
+ // for _, cert := range ca {
+ // tlsConfig.ClientCAs.AddCert(cert)
+ // }
+ // }
+ server.TLSConfig = tlsConfig
+ }
+
+ return server, nil
+}
+
+func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDescriptor) (http.HandlerFunc, error) {
+ transport := &http.Transport{}
+ samlSp := make(map[string]*samlsp.Middleware, 0)
+ spMu := &sync.Mutex{}
+ if sp.Backend.Identity != "" {
+ myIdentity := mtls.DefaultIdentity()
+ tlsConfig, err := myIdentity.TlsConfig(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ verifier := mtls.NewPeerNameVerifier()
+ verifier.AllowFrom(mtls.Service, sp.Backend.Identity)
+ err = verifier.ConfigureClient(tlsConfig)
+ if err != nil {
+ return nil, err
+ }
+
+ transport.TLSClientConfig = tlsConfig
+ }
+ client := &http.Client{
+ Transport: transport,
+ }
+
+ handle := func(w http.ResponseWriter, r *http.Request) {
+ // ensure host header present
+ host := r.Header.Get("Host")
+ if host == "" {
+ host = r.Header.Get(":authority")
+ }
+ if host == "" {
+ host = r.Host
+ }
+ if host == "" {
+ r.Header.Write(os.Stderr)
+ fmt.Fprintf(os.Stderr, "%v\n", r.URL.String())
+ sp.writeError(w, http.StatusBadRequest, errors.New("missing Host header"))
+ return
+ }
+
+ // ensure client isn't trying to inject saml-related headers
+ if err := sp.checkRequest(r); err != nil {
+ sp.writeError(w, http.StatusBadRequest, err)
+ return
+ }
+
+ // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
+ // was reused because our certificate is also valid for the SSO URL.
+ for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
+ for _, ssoSvc := range ssoDesc.SingleSignOnServices {
+ if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
+ if loginUrl.Host == host {
+ sp.writeError(w, http.StatusMisdirectedRequest,
+ errors.New("Misdirected request: this is not the IDP you're looking for"))
+
+ return
+ }
+ }
+ }
+ }
+ // ensure we have SP instance
+ spMu.Lock()
+ if _, ok := samlSp[host]; !ok {
+ provider, err := samlsp.New(samlsp.Options{
+ EntityID: sp.Listener.EntityID,
+ URL: url.URL{
+ Scheme: "https",
+ Host: host,
+ },
+ Key: sp.entityKey,
+ Certificate: sp.entityCert,
+ IDPMetadata: idpMetadata,
+ })
+ if err != nil {
+ sp.writeError(w, http.StatusInternalServerError, err)
+ return
+ }
+
+ samlSp[host] = provider
+ }
+ spMu.Unlock()
+
+ provider := samlSp[host]
+ if r.URL.Path == "/saml/acs" {
+ provider.ServeACS(w, r)
+ return
+ }
+
+ session, sessionErr := provider.Session.GetSession(r)
+
+ if sessionErr != nil && sessionErr != samlsp.ErrNoSession {
+ sp.logger.V(2).Warningf("non-NoSession err from sp: %v", sessionErr)
+ sp.writeError(w, http.StatusBadRequest, sessionErr)
+ return
+ }
+
+ defaultRoute := true
+ sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
+ for _, route := range sp.Listener.Routes {
+ match := false
+ if route.Path != nil {
+ match = route.Path.Match(r.URL.Path)
+ sp.logger.V(3).Debugf("path %s matches %s: %t",
+ r.URL.Path, route.Path.String(), match)
+ } else {
+ sp.writeError(w, http.StatusInternalServerError,
+ errors.New("nothing to match on in route"))
+ }
+
+ if match {
+ defaultRoute = false
+ if sessionErr == samlsp.ErrNoSession && route.Auth == AuthRequired {
+ sp.logger.V(3).Debugf("route requires a valid session, redirecting")
+
+ provider.HandleStartAuthFlow(w, r)
+ return
+ }
+ }
+ }
+
+ if defaultRoute {
+ sp.logger.V(3).Debugf("using default route")
+ if sessionErr == samlsp.ErrNoSession {
+ sp.logger.V(3).Debugf("default route requires a valid session, redirecting")
+ provider.HandleStartAuthFlow(w, r)
+ return
+ }
+ }
+
+ if session != nil {
+ sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session)
+ } else {
+ sp.logger.V(3).Debugf("serving path %s without session", r.URL.Path)
+ }
+
+ newReq := r.Clone(r.Context())
+ newReq.URL.Scheme = "http"
+ if sp.Backend.Identity != "" {
+ newReq.URL.Scheme = "https"
+ }
+ newReq.URL.Host = net.JoinHostPort(sp.Backend.Host, strconv.Itoa(sp.Backend.Port))
+ newReq.RequestURI = ""
+
+ if swa, ok := session.(samlsp.SessionWithAttributes); ok {
+ attrs := swa.GetAttributes()
+ sp.logger.V(3).Debugf("setting origin request header: on-behalf-of: %q", attrs.Get("uid"))
+ newReq.Header.Set("on-behalf-of", attrs.Get("uid"))
+ }
+
+ if jwts, ok := session.(samlsp.JWTSessionClaims); ok {
+ newReq.Header.Set("x-saml-audience", jwts.StandardClaims.Audience)
+ sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience)
+
+ iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10)
+ newReq.Header.Set("x-saml-issued-at", iat)
+ sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat)
+
+ eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10)
+ newReq.Header.Set("x-saml-expires-at", eat)
+ sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat)
+
+ newReq.Header.Set("x-saml-subject", jwts.StandardClaims.Subject)
+ sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject)
+
+ for attr, values := range jwts.Attributes {
+ headerName := fmt.Sprintf("x-saml-%s",
+ samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-"))
+ headerValue := strings.Join(values, ", ")
+ sp.logger.V(3).Debugf("setting origin request header: %s: %s",
+ headerName, headerValue)
+ newReq.Header.Set(headerName, headerValue)
+ }
+ }
+
+ // set proxy headers
+ if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
+ sp.logger.V(3).Debugf("x-forwarded-for: %s", remoteHost)
+ newReq.Header.Set("x-forwarded-for", remoteHost)
+ }
+
+ // proxy the request to the backend
+ response, err := client.Do(newReq)
+ if err != nil {
+ sp.writeError(w, http.StatusBadGateway, err)
+ return
+ }
+
+ for name, value := range response.Header {
+ w.Header().Set(name, strings.Join(value, ", "))
+ }
+ w.WriteHeader(response.StatusCode)
+
+ io.Copy(w, response.Body)
+ }
+
+ return handle, nil
+}
+
+func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) {
+ sp.logger.V(1).Warningf("returning status: %d %s", status, err.Error())
+
+ w.WriteHeader(status)
+ w.Write([]byte(fmt.Sprintf("<h1>%d %s</h1>", status, err.Error())))
+}
+
+func (sp *SAMLProxy) checkRequest(r *http.Request) error {
+ for k, _ := range r.Header {
+ k = strings.ToLower(k)
+ if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) {
+ return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k)
+ }
+ }
+ return nil
+}