]> go.fuhry.dev Git - runtime.git/commitdiff
[http] add samlproxy
authorDan Fuhry <dan@fuhry.com>
Sat, 22 Mar 2025 01:14:27 +0000 (21:14 -0400)
committerDan Fuhry <dan@fuhry.com>
Sat, 22 Mar 2025 01:14:27 +0000 (21:14 -0400)
Add a basic SAML-enforcing sidecar proxy for future use with various internal services.

.gitignore
http/samlproxy.go [new file with mode: 0644]
http/samlproxy/Makefile [new file with mode: 0644]
http/samlproxy/main.go [new file with mode: 0644]

index 729a6e77547c54dfbccdd2ffee8e8ce88aaa38c4..187165e7a70e5678e816121e1996583c51a490ff 100644 (file)
@@ -49,5 +49,6 @@ mtls/verify_tool/verify_tool
 ldap/health_exporter/health_exporter
 envoy/xds/envoy_xds/envoy_xds
 mtls/mtls_exporter/mtls_exporter
+http/samlproxy/samlproxy
 
 /vendor/
diff --git a/http/samlproxy.go b/http/samlproxy.go
new file mode 100644 (file)
index 0000000..1778ca0
--- /dev/null
@@ -0,0 +1,420 @@
+package http
+
+import (
+       "context"
+       "crypto/rsa"
+       "crypto/x509"
+       "crypto/x509/pkix"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "io"
+       "math/big"
+       "net"
+       "net/http"
+       "net/url"
+       "os"
+       "regexp"
+       "strconv"
+       "strings"
+       "sync"
+       "time"
+
+       "github.com/crewjam/saml"
+       "github.com/crewjam/saml/samlsp"
+       "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/mtls/certutil"
+       "go.fuhry.dev/runtime/utils/hashset"
+       "go.fuhry.dev/runtime/utils/log"
+       "go.fuhry.dev/runtime/utils/stringmatch"
+       "gopkg.in/yaml.v3"
+)
+
+type authEnforcement uint
+
+const (
+       AuthRequired authEnforcement = iota
+       AuthOptional
+)
+
+var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`)
+var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"})
+
+// UnmarshalJSON implements json.Unmarshaler
+func (ae *authEnforcement) UnmarshalJSON(token []byte) error {
+       var s string
+       if err := json.Unmarshal(token, &s); err != nil {
+               return errors.New("cannot unmarshal authEnforcement to string")
+       }
+       switch s {
+       case "required":
+               *ae = AuthRequired
+       case "optional":
+               *ae = AuthOptional
+       default:
+               return fmt.Errorf("invalid auth enforcement string value: %s", s)
+       }
+
+       return nil
+}
+
+// UnmarshalYAML implements yaml.Unmarshaler
+func (ae *authEnforcement) UnmarshalYAML(node *yaml.Node) error {
+       if node.Kind != yaml.ScalarNode {
+               return errors.New("yaml field of type authEnforcement must be a string")
+       }
+
+       switch node.Value {
+       case "required":
+               *ae = AuthRequired
+       case "optional":
+               *ae = AuthOptional
+       default:
+               return fmt.Errorf("invalid auth enforcement string value: %s", node.Value)
+       }
+
+       return nil
+}
+
+type Route struct {
+       Auth authEnforcement
+       Path stringmatch.StringMatcher
+}
+
+type SAMLListener struct {
+       EntityID          string  `yaml:"entity_id"`
+       EntityCertificate string  `yaml:"entity_certificate"`
+       EntityPrivateKey  string  `yaml:"entity_key"`
+       IDP               string  `yaml:"idp"`
+       Certificate       string  `yaml:"cert"`
+       Routes            []Route `yaml:"routes"`
+}
+
+type SAMLBackend struct {
+       Host     string `yaml:"host"`
+       Port     int    `yaml:"port"`
+       Identity string `yaml:"mtls_id"`
+}
+
+type SAMLProxy struct {
+       Listener SAMLListener `yaml:"listener"`
+       Backend  SAMLBackend  `yaml:"backend"`
+
+       logger     *log.Logger
+       entityCert *x509.Certificate
+       entityKey  *rsa.PrivateKey
+}
+
+func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
+       if sp.logger == nil {
+               sp.logger = log.WithPrefix("SAMLProxy")
+       }
+
+       idpMetadataUrl, err := url.Parse(sp.Listener.IDP)
+       if err != nil {
+               return nil, err
+       }
+       idpMetadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataUrl)
+       if err != nil {
+               return nil, err
+       }
+
+       handler, err := sp.newHandler(ctx, idpMetadata)
+       if err != nil {
+               return nil, err
+       }
+
+       if sp.Listener.EntityPrivateKey != "" {
+               pvk, err := certutil.LoadPrivateKeyFromPEM(sp.Listener.EntityPrivateKey)
+               if err != nil {
+                       return nil, err
+               }
+               rsaKey, ok := pvk.(*rsa.PrivateKey)
+               if !ok {
+                       return nil, fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
+               }
+               sp.entityKey = rsaKey
+       } else {
+               // generate new RSA private key
+               sp.entityKey, err = rsa.GenerateKey(saml.RandReader, 2048)
+               if err != nil {
+                       return nil, err
+               }
+       }
+       if sp.Listener.EntityCertificate != "" {
+               certs, err := certutil.LoadCertificatesFromPEM(sp.Listener.EntityCertificate)
+               if err != nil {
+                       return nil, err
+               }
+               sp.entityCert = certs[0]
+       } else {
+               // generate new self-signed X509 certificate
+               serialBytes := make([]byte, 16)
+               saml.RandReader.Read(serialBytes)
+               serial := big.NewInt(0)
+               serial.SetBytes(serialBytes)
+
+               template := &x509.Certificate{
+                       Subject: pkix.Name{
+                               CommonName: sp.Listener.EntityID,
+                       },
+                       SerialNumber:          serial,
+                       BasicConstraintsValid: true,
+                       ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+                       KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
+                       IsCA:                  false,
+                       NotBefore:             time.Now(),
+                       NotAfter:              time.Now().Add(90 * 86400 * time.Second),
+               }
+               certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &sp.entityKey.PublicKey, sp.entityKey)
+               if err != nil {
+                       return nil, err
+               }
+               cert, err := x509.ParseCertificate(certBytes)
+               if err != nil {
+                       return nil, err
+               }
+               sp.entityCert = cert
+       }
+
+       server := &http.Server{
+               Handler: handler,
+       }
+
+       if sp.Listener.Certificate != "" {
+               cert := mtls.NewSSLCertificate(sp.Listener.Certificate)
+               tlsConfig, err := cert.TlsConfig(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               // verif := mtls.NewPeerNameVerifier()
+               // verif.AllowFrom(mtls.Service, "devicetrust")
+               // verif.ConfigureServer(tlsConfig)
+               // tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+               // tlsConfig.ClientCAs = x509.NewCertPool()
+               // ca, err := certutil.LoadCertificatesFromPEM("/etc/ssl/mtls/rootca.pem")
+               // if err == nil {
+               //      for _, cert := range ca {
+               //              tlsConfig.ClientCAs.AddCert(cert)
+               //      }
+               // }
+               server.TLSConfig = tlsConfig
+       }
+
+       return server, nil
+}
+
+func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDescriptor) (http.HandlerFunc, error) {
+       transport := &http.Transport{}
+       samlSp := make(map[string]*samlsp.Middleware, 0)
+       spMu := &sync.Mutex{}
+       if sp.Backend.Identity != "" {
+               myIdentity := mtls.DefaultIdentity()
+               tlsConfig, err := myIdentity.TlsConfig(ctx)
+               if err != nil {
+                       return nil, err
+               }
+
+               verifier := mtls.NewPeerNameVerifier()
+               verifier.AllowFrom(mtls.Service, sp.Backend.Identity)
+               err = verifier.ConfigureClient(tlsConfig)
+               if err != nil {
+                       return nil, err
+               }
+
+               transport.TLSClientConfig = tlsConfig
+       }
+       client := &http.Client{
+               Transport: transport,
+       }
+
+       handle := func(w http.ResponseWriter, r *http.Request) {
+               // ensure host header present
+               host := r.Header.Get("Host")
+               if host == "" {
+                       host = r.Header.Get(":authority")
+               }
+               if host == "" {
+                       host = r.Host
+               }
+               if host == "" {
+                       r.Header.Write(os.Stderr)
+                       fmt.Fprintf(os.Stderr, "%v\n", r.URL.String())
+                       sp.writeError(w, http.StatusBadRequest, errors.New("missing Host header"))
+                       return
+               }
+
+               // ensure client isn't trying to inject saml-related headers
+               if err := sp.checkRequest(r); err != nil {
+                       sp.writeError(w, http.StatusBadRequest, err)
+                       return
+               }
+
+               // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
+               // was reused because our certificate is also valid for the SSO URL.
+               for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
+                       for _, ssoSvc := range ssoDesc.SingleSignOnServices {
+                               if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
+                                       if loginUrl.Host == host {
+                                               sp.writeError(w, http.StatusMisdirectedRequest,
+                                                       errors.New("Misdirected request: this is not the IDP you're looking for"))
+
+                                               return
+                                       }
+                               }
+                       }
+               }
+               // ensure we have SP instance
+               spMu.Lock()
+               if _, ok := samlSp[host]; !ok {
+                       provider, err := samlsp.New(samlsp.Options{
+                               EntityID: sp.Listener.EntityID,
+                               URL: url.URL{
+                                       Scheme: "https",
+                                       Host:   host,
+                               },
+                               Key:         sp.entityKey,
+                               Certificate: sp.entityCert,
+                               IDPMetadata: idpMetadata,
+                       })
+                       if err != nil {
+                               sp.writeError(w, http.StatusInternalServerError, err)
+                               return
+                       }
+
+                       samlSp[host] = provider
+               }
+               spMu.Unlock()
+
+               provider := samlSp[host]
+               if r.URL.Path == "/saml/acs" {
+                       provider.ServeACS(w, r)
+                       return
+               }
+
+               session, sessionErr := provider.Session.GetSession(r)
+
+               if sessionErr != nil && sessionErr != samlsp.ErrNoSession {
+                       sp.logger.V(2).Warningf("non-NoSession err from sp: %v", sessionErr)
+                       sp.writeError(w, http.StatusBadRequest, sessionErr)
+                       return
+               }
+
+               defaultRoute := true
+               sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
+               for _, route := range sp.Listener.Routes {
+                       match := false
+                       if route.Path != nil {
+                               match = route.Path.Match(r.URL.Path)
+                               sp.logger.V(3).Debugf("path %s matches %s: %t",
+                                       r.URL.Path, route.Path.String(), match)
+                       } else {
+                               sp.writeError(w, http.StatusInternalServerError,
+                                       errors.New("nothing to match on in route"))
+                       }
+
+                       if match {
+                               defaultRoute = false
+                               if sessionErr == samlsp.ErrNoSession && route.Auth == AuthRequired {
+                                       sp.logger.V(3).Debugf("route requires a valid session, redirecting")
+
+                                       provider.HandleStartAuthFlow(w, r)
+                                       return
+                               }
+                       }
+               }
+
+               if defaultRoute {
+                       sp.logger.V(3).Debugf("using default route")
+                       if sessionErr == samlsp.ErrNoSession {
+                               sp.logger.V(3).Debugf("default route requires a valid session, redirecting")
+                               provider.HandleStartAuthFlow(w, r)
+                               return
+                       }
+               }
+
+               if session != nil {
+                       sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session)
+               } else {
+                       sp.logger.V(3).Debugf("serving path %s without session", r.URL.Path)
+               }
+
+               newReq := r.Clone(r.Context())
+               newReq.URL.Scheme = "http"
+               if sp.Backend.Identity != "" {
+                       newReq.URL.Scheme = "https"
+               }
+               newReq.URL.Host = net.JoinHostPort(sp.Backend.Host, strconv.Itoa(sp.Backend.Port))
+               newReq.RequestURI = ""
+
+               if swa, ok := session.(samlsp.SessionWithAttributes); ok {
+                       attrs := swa.GetAttributes()
+                       sp.logger.V(3).Debugf("setting origin request header: on-behalf-of: %q", attrs.Get("uid"))
+                       newReq.Header.Set("on-behalf-of", attrs.Get("uid"))
+               }
+
+               if jwts, ok := session.(samlsp.JWTSessionClaims); ok {
+                       newReq.Header.Set("x-saml-audience", jwts.StandardClaims.Audience)
+                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience)
+
+                       iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10)
+                       newReq.Header.Set("x-saml-issued-at", iat)
+                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat)
+
+                       eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10)
+                       newReq.Header.Set("x-saml-expires-at", eat)
+                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat)
+
+                       newReq.Header.Set("x-saml-subject", jwts.StandardClaims.Subject)
+                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject)
+
+                       for attr, values := range jwts.Attributes {
+                               headerName := fmt.Sprintf("x-saml-%s",
+                                       samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-"))
+                               headerValue := strings.Join(values, ", ")
+                               sp.logger.V(3).Debugf("setting origin request header: %s: %s",
+                                       headerName, headerValue)
+                               newReq.Header.Set(headerName, headerValue)
+                       }
+               }
+
+               // set proxy headers
+               if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
+                       sp.logger.V(3).Debugf("x-forwarded-for: %s", remoteHost)
+                       newReq.Header.Set("x-forwarded-for", remoteHost)
+               }
+
+               // proxy the request to the backend
+               response, err := client.Do(newReq)
+               if err != nil {
+                       sp.writeError(w, http.StatusBadGateway, err)
+                       return
+               }
+
+               for name, value := range response.Header {
+                       w.Header().Set(name, strings.Join(value, ", "))
+               }
+               w.WriteHeader(response.StatusCode)
+
+               io.Copy(w, response.Body)
+       }
+
+       return handle, nil
+}
+
+func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) {
+       sp.logger.V(1).Warningf("returning status: %d %s", status, err.Error())
+
+       w.WriteHeader(status)
+       w.Write([]byte(fmt.Sprintf("<h1>%d %s</h1>", status, err.Error())))
+}
+
+func (sp *SAMLProxy) checkRequest(r *http.Request) error {
+       for k, _ := range r.Header {
+               k = strings.ToLower(k)
+               if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) {
+                       return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k)
+               }
+       }
+       return nil
+}
diff --git a/http/samlproxy/Makefile b/http/samlproxy/Makefile
new file mode 100644 (file)
index 0000000..bfab546
--- /dev/null
@@ -0,0 +1,14 @@
+GOSRC = $(wildcard *.go)
+GOEXE = $(shell basename `pwd`)
+GOBUILDFLAGS := -buildmode=pie -trimpath
+
+all: $(GOEXE)
+
+clean:
+       rm -fv $(GOEXE)
+
+.PHONY: all clean
+
+$(GOEXE): %: $(GOSRC)
+       go build $(GOBUILDFLAGS) -o $@ $<
+
diff --git a/http/samlproxy/main.go b/http/samlproxy/main.go
new file mode 100644 (file)
index 0000000..d473e2c
--- /dev/null
@@ -0,0 +1,50 @@
+package main
+
+import (
+       "context"
+       "flag"
+       "os/signal"
+       "syscall"
+       "time"
+
+       "go.fuhry.dev/runtime/http"
+       "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/utils/log"
+       "go.fuhry.dev/runtime/utils/stringmatch"
+)
+
+func main() {
+       mtls.SetDefaultIdentity("authproxy")
+
+       sp := http.SAMLProxy{}
+       flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider")
+       flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata")
+       flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private")
+       flag.StringVar(&sp.Backend.Host, "backend.host", "127.0.0.1", "backend host")
+       flag.IntVar(&sp.Backend.Port, "backend.port", 0, "backend port")
+       flag.StringVar(&sp.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend")
+       listen := flag.String("listen", "[::]:8443", "address for auth proxy to listen on")
+
+       flag.Parse()
+
+       sp.Listener.Routes = []http.Route{
+               {
+                       Auth: http.AuthOptional,
+                       Path: stringmatch.Exact("/trickortreat.html"),
+               },
+       }
+
+       ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+       defer cancel()
+       server, err := sp.NewHTTPServerWithContext(ctx)
+       if err != nil {
+               log.Panic(err)
+       }
+       server.Addr = *listen
+       go server.ListenAndServeTLS("", "")
+
+       <-ctx.Done()
+       shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
+       defer shutdownCancel()
+       server.Shutdown(shutdownCtx)
+}