From f493ec81ee16a86ea2c620a056f2e10b558d8aa0 Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Fri, 21 Mar 2025 21:14:27 -0400 Subject: [PATCH] [http] add samlproxy Add a basic SAML-enforcing sidecar proxy for future use with various internal services. --- .gitignore | 1 + http/samlproxy.go | 420 ++++++++++++++++++++++++++++++++++++++++ http/samlproxy/Makefile | 14 ++ http/samlproxy/main.go | 50 +++++ 4 files changed, 485 insertions(+) create mode 100644 http/samlproxy.go create mode 100644 http/samlproxy/Makefile create mode 100644 http/samlproxy/main.go diff --git a/.gitignore b/.gitignore index 729a6e7..187165e 100644 --- a/.gitignore +++ b/.gitignore @@ -49,5 +49,6 @@ mtls/verify_tool/verify_tool ldap/health_exporter/health_exporter envoy/xds/envoy_xds/envoy_xds mtls/mtls_exporter/mtls_exporter +http/samlproxy/samlproxy /vendor/ diff --git a/http/samlproxy.go b/http/samlproxy.go new file mode 100644 index 0000000..1778ca0 --- /dev/null +++ b/http/samlproxy.go @@ -0,0 +1,420 @@ +package http + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/url" + "os" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/mtls/certutil" + "go.fuhry.dev/runtime/utils/hashset" + "go.fuhry.dev/runtime/utils/log" + "go.fuhry.dev/runtime/utils/stringmatch" + "gopkg.in/yaml.v3" +) + +type authEnforcement uint + +const ( + AuthRequired authEnforcement = iota + AuthOptional +) + +var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`) +var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"}) + +// UnmarshalJSON implements json.Unmarshaler +func (ae *authEnforcement) UnmarshalJSON(token []byte) error { + var s string + if err := json.Unmarshal(token, &s); err != nil { + return errors.New("cannot unmarshal authEnforcement to string") + } + switch s { + case "required": + *ae = AuthRequired + case "optional": + *ae = AuthOptional + default: + return fmt.Errorf("invalid auth enforcement string value: %s", s) + } + + return nil +} + +// UnmarshalYAML implements yaml.Unmarshaler +func (ae *authEnforcement) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.ScalarNode { + return errors.New("yaml field of type authEnforcement must be a string") + } + + switch node.Value { + case "required": + *ae = AuthRequired + case "optional": + *ae = AuthOptional + default: + return fmt.Errorf("invalid auth enforcement string value: %s", node.Value) + } + + return nil +} + +type Route struct { + Auth authEnforcement + Path stringmatch.StringMatcher +} + +type SAMLListener struct { + EntityID string `yaml:"entity_id"` + EntityCertificate string `yaml:"entity_certificate"` + EntityPrivateKey string `yaml:"entity_key"` + IDP string `yaml:"idp"` + Certificate string `yaml:"cert"` + Routes []Route `yaml:"routes"` +} + +type SAMLBackend struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Identity string `yaml:"mtls_id"` +} + +type SAMLProxy struct { + Listener SAMLListener `yaml:"listener"` + Backend SAMLBackend `yaml:"backend"` + + logger *log.Logger + entityCert *x509.Certificate + entityKey *rsa.PrivateKey +} + +func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { + if sp.logger == nil { + sp.logger = log.WithPrefix("SAMLProxy") + } + + idpMetadataUrl, err := url.Parse(sp.Listener.IDP) + if err != nil { + return nil, err + } + idpMetadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataUrl) + if err != nil { + return nil, err + } + + handler, err := sp.newHandler(ctx, idpMetadata) + if err != nil { + return nil, err + } + + if sp.Listener.EntityPrivateKey != "" { + pvk, err := certutil.LoadPrivateKeyFromPEM(sp.Listener.EntityPrivateKey) + if err != nil { + return nil, err + } + rsaKey, ok := pvk.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk) + } + sp.entityKey = rsaKey + } else { + // generate new RSA private key + sp.entityKey, err = rsa.GenerateKey(saml.RandReader, 2048) + if err != nil { + return nil, err + } + } + if sp.Listener.EntityCertificate != "" { + certs, err := certutil.LoadCertificatesFromPEM(sp.Listener.EntityCertificate) + if err != nil { + return nil, err + } + sp.entityCert = certs[0] + } else { + // generate new self-signed X509 certificate + serialBytes := make([]byte, 16) + saml.RandReader.Read(serialBytes) + serial := big.NewInt(0) + serial.SetBytes(serialBytes) + + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: sp.Listener.EntityID, + }, + SerialNumber: serial, + BasicConstraintsValid: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, + IsCA: false, + NotBefore: time.Now(), + NotAfter: time.Now().Add(90 * 86400 * time.Second), + } + certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &sp.entityKey.PublicKey, sp.entityKey) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, err + } + sp.entityCert = cert + } + + server := &http.Server{ + Handler: handler, + } + + if sp.Listener.Certificate != "" { + cert := mtls.NewSSLCertificate(sp.Listener.Certificate) + tlsConfig, err := cert.TlsConfig(ctx) + if err != nil { + return nil, err + } + // verif := mtls.NewPeerNameVerifier() + // verif.AllowFrom(mtls.Service, "devicetrust") + // verif.ConfigureServer(tlsConfig) + // tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + // tlsConfig.ClientCAs = x509.NewCertPool() + // ca, err := certutil.LoadCertificatesFromPEM("/etc/ssl/mtls/rootca.pem") + // if err == nil { + // for _, cert := range ca { + // tlsConfig.ClientCAs.AddCert(cert) + // } + // } + server.TLSConfig = tlsConfig + } + + return server, nil +} + +func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDescriptor) (http.HandlerFunc, error) { + transport := &http.Transport{} + samlSp := make(map[string]*samlsp.Middleware, 0) + spMu := &sync.Mutex{} + if sp.Backend.Identity != "" { + myIdentity := mtls.DefaultIdentity() + tlsConfig, err := myIdentity.TlsConfig(ctx) + if err != nil { + return nil, err + } + + verifier := mtls.NewPeerNameVerifier() + verifier.AllowFrom(mtls.Service, sp.Backend.Identity) + err = verifier.ConfigureClient(tlsConfig) + if err != nil { + return nil, err + } + + transport.TLSClientConfig = tlsConfig + } + client := &http.Client{ + Transport: transport, + } + + handle := func(w http.ResponseWriter, r *http.Request) { + // ensure host header present + host := r.Header.Get("Host") + if host == "" { + host = r.Header.Get(":authority") + } + if host == "" { + host = r.Host + } + if host == "" { + r.Header.Write(os.Stderr) + fmt.Fprintf(os.Stderr, "%v\n", r.URL.String()) + sp.writeError(w, http.StatusBadRequest, errors.New("missing Host header")) + return + } + + // ensure client isn't trying to inject saml-related headers + if err := sp.checkRequest(r); err != nil { + sp.writeError(w, http.StatusBadRequest, err) + return + } + + // make sure the browser isn't trying to access the IdP - this can happen if the TLS session + // was reused because our certificate is also valid for the SSO URL. + for _, ssoDesc := range idpMetadata.IDPSSODescriptors { + for _, ssoSvc := range ssoDesc.SingleSignOnServices { + if loginUrl, err := url.Parse(ssoSvc.Location); err == nil { + if loginUrl.Host == host { + sp.writeError(w, http.StatusMisdirectedRequest, + errors.New("Misdirected request: this is not the IDP you're looking for")) + + return + } + } + } + } + // ensure we have SP instance + spMu.Lock() + if _, ok := samlSp[host]; !ok { + provider, err := samlsp.New(samlsp.Options{ + EntityID: sp.Listener.EntityID, + URL: url.URL{ + Scheme: "https", + Host: host, + }, + Key: sp.entityKey, + Certificate: sp.entityCert, + IDPMetadata: idpMetadata, + }) + if err != nil { + sp.writeError(w, http.StatusInternalServerError, err) + return + } + + samlSp[host] = provider + } + spMu.Unlock() + + provider := samlSp[host] + if r.URL.Path == "/saml/acs" { + provider.ServeACS(w, r) + return + } + + session, sessionErr := provider.Session.GetSession(r) + + if sessionErr != nil && sessionErr != samlsp.ErrNoSession { + sp.logger.V(2).Warningf("non-NoSession err from sp: %v", sessionErr) + sp.writeError(w, http.StatusBadRequest, sessionErr) + return + } + + defaultRoute := true + sp.logger.V(3).Debugf("checking for routes matching %s", r.URL) + for _, route := range sp.Listener.Routes { + match := false + if route.Path != nil { + match = route.Path.Match(r.URL.Path) + sp.logger.V(3).Debugf("path %s matches %s: %t", + r.URL.Path, route.Path.String(), match) + } else { + sp.writeError(w, http.StatusInternalServerError, + errors.New("nothing to match on in route")) + } + + if match { + defaultRoute = false + if sessionErr == samlsp.ErrNoSession && route.Auth == AuthRequired { + sp.logger.V(3).Debugf("route requires a valid session, redirecting") + + provider.HandleStartAuthFlow(w, r) + return + } + } + } + + if defaultRoute { + sp.logger.V(3).Debugf("using default route") + if sessionErr == samlsp.ErrNoSession { + sp.logger.V(3).Debugf("default route requires a valid session, redirecting") + provider.HandleStartAuthFlow(w, r) + return + } + } + + if session != nil { + sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session) + } else { + sp.logger.V(3).Debugf("serving path %s without session", r.URL.Path) + } + + newReq := r.Clone(r.Context()) + newReq.URL.Scheme = "http" + if sp.Backend.Identity != "" { + newReq.URL.Scheme = "https" + } + newReq.URL.Host = net.JoinHostPort(sp.Backend.Host, strconv.Itoa(sp.Backend.Port)) + newReq.RequestURI = "" + + if swa, ok := session.(samlsp.SessionWithAttributes); ok { + attrs := swa.GetAttributes() + sp.logger.V(3).Debugf("setting origin request header: on-behalf-of: %q", attrs.Get("uid")) + newReq.Header.Set("on-behalf-of", attrs.Get("uid")) + } + + if jwts, ok := session.(samlsp.JWTSessionClaims); ok { + newReq.Header.Set("x-saml-audience", jwts.StandardClaims.Audience) + sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience) + + iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10) + newReq.Header.Set("x-saml-issued-at", iat) + sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat) + + eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10) + newReq.Header.Set("x-saml-expires-at", eat) + sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat) + + newReq.Header.Set("x-saml-subject", jwts.StandardClaims.Subject) + sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject) + + for attr, values := range jwts.Attributes { + headerName := fmt.Sprintf("x-saml-%s", + samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-")) + headerValue := strings.Join(values, ", ") + sp.logger.V(3).Debugf("setting origin request header: %s: %s", + headerName, headerValue) + newReq.Header.Set(headerName, headerValue) + } + } + + // set proxy headers + if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + sp.logger.V(3).Debugf("x-forwarded-for: %s", remoteHost) + newReq.Header.Set("x-forwarded-for", remoteHost) + } + + // proxy the request to the backend + response, err := client.Do(newReq) + if err != nil { + sp.writeError(w, http.StatusBadGateway, err) + return + } + + for name, value := range response.Header { + w.Header().Set(name, strings.Join(value, ", ")) + } + w.WriteHeader(response.StatusCode) + + io.Copy(w, response.Body) + } + + return handle, nil +} + +func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) { + sp.logger.V(1).Warningf("returning status: %d %s", status, err.Error()) + + w.WriteHeader(status) + w.Write([]byte(fmt.Sprintf("

%d %s

", status, err.Error()))) +} + +func (sp *SAMLProxy) checkRequest(r *http.Request) error { + for k, _ := range r.Header { + k = strings.ToLower(k) + if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) { + return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k) + } + } + return nil +} diff --git a/http/samlproxy/Makefile b/http/samlproxy/Makefile new file mode 100644 index 0000000..bfab546 --- /dev/null +++ b/http/samlproxy/Makefile @@ -0,0 +1,14 @@ +GOSRC = $(wildcard *.go) +GOEXE = $(shell basename `pwd`) +GOBUILDFLAGS := -buildmode=pie -trimpath + +all: $(GOEXE) + +clean: + rm -fv $(GOEXE) + +.PHONY: all clean + +$(GOEXE): %: $(GOSRC) + go build $(GOBUILDFLAGS) -o $@ $< + diff --git a/http/samlproxy/main.go b/http/samlproxy/main.go new file mode 100644 index 0000000..d473e2c --- /dev/null +++ b/http/samlproxy/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "flag" + "os/signal" + "syscall" + "time" + + "go.fuhry.dev/runtime/http" + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/utils/log" + "go.fuhry.dev/runtime/utils/stringmatch" +) + +func main() { + mtls.SetDefaultIdentity("authproxy") + + sp := http.SAMLProxy{} + flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider") + flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata") + flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private") + flag.StringVar(&sp.Backend.Host, "backend.host", "127.0.0.1", "backend host") + flag.IntVar(&sp.Backend.Port, "backend.port", 0, "backend port") + flag.StringVar(&sp.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend") + listen := flag.String("listen", "[::]:8443", "address for auth proxy to listen on") + + flag.Parse() + + sp.Listener.Routes = []http.Route{ + { + Auth: http.AuthOptional, + Path: stringmatch.Exact("/trickortreat.html"), + }, + } + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + server, err := sp.NewHTTPServerWithContext(ctx) + if err != nil { + log.Panic(err) + } + server.Addr = *listen + go server.ListenAndServeTLS("", "") + + <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + server.Shutdown(shutdownCtx) +} -- 2.50.1