From: Dan Fuhry Date: Sun, 23 Mar 2025 01:30:11 +0000 (-0400) Subject: [http/samlproxy] multiple vhosts, loadable yaml config, SAML config per vhost X-Git-Url: https://go.fuhry.dev/?a=commitdiff_plain;h=d12699bfa826fd293cb1cfd65deb764061bea7d6;p=runtime.git [http/samlproxy] multiple vhosts, loadable yaml config, SAML config per vhost --- diff --git a/http/samlproxy.go b/http/samlproxy.go index 178dd5f..c39e6d3 100644 --- a/http/samlproxy.go +++ b/http/samlproxy.go @@ -2,7 +2,9 @@ package http import ( "context" + "crypto" "crypto/rsa" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/json" @@ -32,6 +34,54 @@ import ( type authEnforcement uint +type Route struct { + Auth authEnforcement + Path stringmatch.StringMatcher +} + +type SAMLBackend struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Identity string `yaml:"mtls_id"` + + client *http.Client + clientOnce sync.Once +} + +type SAMLVirtualHost struct { + *SAMLServiceProvider `yaml:"saml"` + + Backend *SAMLBackend `yaml:"backend"` + Routes []*Route `yaml:"routes"` +} + +type SAMLServiceProvider struct { + EntityID string `yaml:"entity_id"` + EntityCertificate string `yaml:"entity_certificate"` + EntityPrivateKey string `yaml:"entity_key"` + IDP string `yaml:"idp"` + + metadata *saml.EntityDescriptor + metadataOnce sync.Once + + entityCert *x509.Certificate + entityKey *rsa.PrivateKey + certKeyOnce sync.Once +} + +type SAMLListener struct { + *SAMLServiceProvider `yaml:"saml"` + + Certificate string `yaml:"cert"` + VirtualHosts map[string]*SAMLVirtualHost `yaml:"virtual_hosts"` +} + +type SAMLProxy struct { + Listener SAMLListener `yaml:"listener"` + + logger log.Logger +} + const ( AuthRequired authEnforcement = iota AuthOptional @@ -59,126 +109,134 @@ func (ae *authEnforcement) UnmarshalJSON(token []byte) error { } // UnmarshalYAML implements yaml.Unmarshaler -func (ae *authEnforcement) UnmarshalYAML(node *yaml.Node) error { - if node.Kind != yaml.ScalarNode { - return errors.New("yaml field of type authEnforcement must be a string") +func (r *Route) UnmarshalYAML(node *yaml.Node) error { + var rawNode struct { + Auth string `yaml:"auth"` + Path *stringmatch.MatchRule `yaml:"path"` } - switch node.Value { + if err := node.Decode(&rawNode); err != nil { + return err + } + + switch rawNode.Auth { case "required": - *ae = AuthRequired + r.Auth = AuthRequired case "optional": - *ae = AuthOptional + r.Auth = AuthOptional default: - return fmt.Errorf("invalid auth enforcement string value: %s", node.Value) + return fmt.Errorf("error unmarshaling route: invalid auth enforcement string value: %s", node.Value) + } + + if rawNode.Path != nil { + m, err := rawNode.Path.Matcher() + if err != nil { + return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err) + } + r.Path = m + } else { + return errors.New("error unmarshaling route: exactly one of (`path`) must be specified") } return nil } -type Route struct { - Auth authEnforcement - Path stringmatch.StringMatcher -} +// RouteFromArg implements the 3rd argument to flag.Func. +// +// It parses a string in the format of auth:field:match_mode:value, returning a Route if +// it parses successfully. +func RouteFromArg(arg string) (*Route, error) { + parts := strings.SplitN(arg, ":", 4) + if len(parts) != 4 { + return nil, fmt.Errorf("invalid route spec: %q", arg) + } + a, f, t, v := parts[0], parts[1], parts[2], parts[3] + var auth authEnforcement + switch strings.ToLower(a) { + case "r", "req", "required": + auth = AuthRequired + case "o", "opt", "optional": + auth = AuthOptional + default: + return nil, fmt.Errorf("invalid auth setting: %q", a) + } -type SAMLListener struct { - EntityID string `yaml:"entity_id"` - EntityCertificate string `yaml:"entity_certificate"` - EntityPrivateKey string `yaml:"entity_key"` - IDP string `yaml:"idp"` - Certificate string `yaml:"cert"` - Routes []Route `yaml:"routes"` -} + route := &Route{ + Auth: auth, + } -type SAMLBackend struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - Identity string `yaml:"mtls_id"` -} + match := stringmatch.MatchRule{ + Mode: t, + Value: v, + } + m, err := match.Matcher() + if err != nil { + return nil, err + } -type SAMLProxy struct { - Listener SAMLListener `yaml:"listener"` - Backend SAMLBackend `yaml:"backend"` + switch strings.ToLower(f) { + case "p", "path": + route.Path = m + default: + return nil, fmt.Errorf("invalid match field: %q", f) + } - logger log.Logger - entityCert *x509.Certificate - entityKey *rsa.PrivateKey + return route, nil } -func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { - if sp.logger == nil { - sp.logger = log.Default().WithPrefix("SAMLProxy") - } +// Client returns an HTTP client for making requests to the backend. +func (b *SAMLBackend) Client() (*http.Client, error) { + var err error + b.clientOnce.Do(func() { + transport := &http.Transport{} + var tlsConfig *tls.Config + + if b.Identity != "" { + myIdentity := mtls.DefaultIdentity() + tlsConfig, err = myIdentity.TlsConfig(context.Background()) + if err != nil { + return + } + + verifier := mtls.NewPeerNameVerifier() + verifier.AllowFrom(mtls.Service, b.Identity) + err = verifier.ConfigureClient(tlsConfig) + if err != nil { + return + } - idpMetadataUrl, err := url.Parse(sp.Listener.IDP) + transport.TLSClientConfig = tlsConfig + } + + client := &http.Client{ + Transport: transport, + } + + b.client = client + }) if err != nil { return nil, err } - idpMetadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataUrl) - if err != nil { - return nil, err + return b.client, nil +} + +// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host +// and other settings. +func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { + var _ yaml.Unmarshaler = &Route{} + + if sp.logger == nil { + sp.logger = log.Default().WithPrefix("SAMLProxy") } - handler, err := sp.newHandler(ctx, idpMetadata) + handler, err := sp.newHandler() if err != nil { return nil, err } - if sp.Listener.EntityPrivateKey != "" { - pvk, err := certutil.LoadPrivateKeyFromPEM(sp.Listener.EntityPrivateKey) - if err != nil { - return nil, err - } - rsaKey, ok := pvk.(*rsa.PrivateKey) - if !ok { - return nil, fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk) - } - sp.entityKey = rsaKey - } else { - // generate new RSA private key - sp.entityKey, err = rsa.GenerateKey(saml.RandReader, 2048) - if err != nil { - return nil, err - } - } - if sp.Listener.EntityCertificate != "" { - certs, err := certutil.LoadCertificatesFromPEM(sp.Listener.EntityCertificate) - if err != nil { - return nil, err - } - sp.entityCert = certs[0] - } else { - // generate new self-signed X509 certificate - serialBytes := make([]byte, 16) - saml.RandReader.Read(serialBytes) - serial := big.NewInt(0) - serial.SetBytes(serialBytes) - - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: sp.Listener.EntityID, - }, - SerialNumber: serial, - BasicConstraintsValid: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, - IsCA: false, - NotBefore: time.Now(), - NotAfter: time.Now().Add(90 * 86400 * time.Second), - } - certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &sp.entityKey.PublicKey, sp.entityKey) - if err != nil { - return nil, err - } - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - return nil, err - } - sp.entityCert = cert - } - + lm := log.NewLoggingMiddlewareWithLogger(handler, sp.logger) server := &http.Server{ - Handler: handler, + Handler: lm.HandlerFunc(), } if sp.Listener.Certificate != "" { @@ -187,46 +245,124 @@ func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server if err != nil { return nil, err } - // verif := mtls.NewPeerNameVerifier() - // verif.AllowFrom(mtls.Service, "devicetrust") - // verif.ConfigureServer(tlsConfig) - // tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - // tlsConfig.ClientCAs = x509.NewCertPool() - // ca, err := certutil.LoadCertificatesFromPEM("/etc/ssl/mtls/rootca.pem") - // if err == nil { - // for _, cert := range ca { - // tlsConfig.ClientCAs.AddCert(cert) - // } - // } server.TLSConfig = tlsConfig } return server, nil } -func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDescriptor) (http.HandlerFunc, error) { - transport := &http.Transport{} - samlSp := make(map[string]*samlsp.Middleware, 0) - spMu := &sync.Mutex{} - if sp.Backend.Identity != "" { - myIdentity := mtls.DefaultIdentity() - tlsConfig, err := myIdentity.TlsConfig(ctx) +func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) { + var err error + sp.metadataOnce.Do(func() { + var idpMetadataUrl *url.URL + idpMetadataUrl, err = url.Parse(sp.IDP) if err != nil { - return nil, err + return } - verifier := mtls.NewPeerNameVerifier() - verifier.AllowFrom(mtls.Service, sp.Backend.Identity) - err = verifier.ConfigureClient(tlsConfig) - if err != nil { - return nil, err + sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl) + }) + + if err != nil { + sp.metadataOnce = sync.Once{} + return nil, err + } + + return sp.metadata, nil +} + +func (sp *SAMLServiceProvider) CertAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) { + sp.certKeyOnce.Do(func() { + if sp.EntityPrivateKey != "" { + var loadedKey crypto.PrivateKey + loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey) + if err != nil { + return + } + var ok bool + pvk, ok = loadedKey.(*rsa.PrivateKey) + if !ok { + err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk) + return + } + } else { + // generate new RSA private key + pvk, err = rsa.GenerateKey(saml.RandReader, 2048) + if err != nil { + return + } + } + if sp.EntityCertificate != "" { + certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate) + if err != nil { + return + } + cert = certs[0] + } else { + // generate new self-signed X509 certificate + serialBytes := make([]byte, 16) + saml.RandReader.Read(serialBytes) + serial := big.NewInt(0) + serial.SetBytes(serialBytes) + + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: sp.EntityID, + }, + SerialNumber: serial, + BasicConstraintsValid: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, + IsCA: false, + NotBefore: time.Now(), + NotAfter: time.Now().Add(90 * 86400 * time.Second), + } + certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk) + if err != nil { + return + } + cert, err = x509.ParseCertificate(certBytes) + if err != nil { + return + } } - transport.TLSClientConfig = tlsConfig + sp.entityCert = cert + sp.entityKey = pvk + }) + + if err != nil { + sp.metadataOnce = sync.Once{} + return nil, nil, err } - client := &http.Client{ - Transport: transport, + + return sp.entityCert, sp.entityKey, nil +} + +func (sp *SAMLServiceProvider) NewServiceProvider(host string) (*samlsp.Middleware, error) { + idpMetadata, err := sp.Metadata() + if err != nil { + return nil, err + } + cert, key, err := sp.CertAndKey() + if err != nil { + return nil, err } + return samlsp.New(samlsp.Options{ + EntityID: sp.EntityID, + URL: url.URL{ + Scheme: "https", + Host: host, + }, + Key: key, + Certificate: cert, + IDPMetadata: idpMetadata, + }) +} + +func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) { + samlSp := make(map[string]*samlsp.Middleware, 0) + spMu := &sync.Mutex{} handle := func(w http.ResponseWriter, r *http.Request) { // ensure host header present @@ -250,39 +386,47 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes return } - // make sure the browser isn't trying to access the IdP - this can happen if the TLS session - // was reused because our certificate is also valid for the SSO URL. - for _, ssoDesc := range idpMetadata.IDPSSODescriptors { - for _, ssoSvc := range ssoDesc.SingleSignOnServices { - if loginUrl, err := url.Parse(ssoSvc.Location); err == nil { - if loginUrl.Host == host { - sp.writeError(w, http.StatusMisdirectedRequest, - errors.New("Misdirected request: this is not the IDP you're looking for")) + // make sure this host is known + vhost, ok := sp.Listener.VirtualHosts[host] + if !ok { + sp.writeError(w, http.StatusMisdirectedRequest, + errors.New("Misdirected request: unknown virtual host")) - return - } - } - } + return } + // ensure we have SP instance spMu.Lock() if _, ok := samlSp[host]; !ok { - provider, err := samlsp.New(samlsp.Options{ - EntityID: sp.Listener.EntityID, - URL: url.URL{ - Scheme: "https", - Host: host, - }, - Key: sp.entityKey, - Certificate: sp.entityCert, - IDPMetadata: idpMetadata, - }) + samlSettings := vhost.SAMLServiceProvider + if samlSettings == nil { + samlSettings = sp.Listener.SAMLServiceProvider + } + idpMetadata, err := samlSettings.Metadata() if err != nil { sp.writeError(w, http.StatusInternalServerError, err) return } - - samlSp[host] = provider + // make sure the browser isn't trying to access the IdP - this can happen if the TLS session + // was reused because our certificate is also valid for the SSO URL. + for _, ssoDesc := range idpMetadata.IDPSSODescriptors { + for _, ssoSvc := range ssoDesc.SingleSignOnServices { + if loginUrl, err := url.Parse(ssoSvc.Location); err == nil { + if loginUrl.Host == host { + sp.writeError(w, http.StatusMisdirectedRequest, + errors.New("Misdirected request: this is not the IDP you're looking for")) + + return + } + } + } + } + middleware, err := samlSettings.NewServiceProvider(host) + if err != nil { + sp.writeError(w, http.StatusInternalServerError, err) + return + } + samlSp[host] = middleware } spMu.Unlock() @@ -302,7 +446,7 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes defaultRoute := true sp.logger.V(3).Debugf("checking for routes matching %s", r.URL) - for _, route := range sp.Listener.Routes { + for _, route := range vhost.Routes { match := false if route.Path != nil { match = route.Path.Match(r.URL.Path) @@ -341,10 +485,10 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes newReq := r.Clone(r.Context()) newReq.URL.Scheme = "http" - if sp.Backend.Identity != "" { + if vhost.Backend.Identity != "" { newReq.URL.Scheme = "https" } - newReq.URL.Host = net.JoinHostPort(sp.Backend.Host, strconv.Itoa(sp.Backend.Port)) + newReq.URL.Host = net.JoinHostPort(vhost.Backend.Host, strconv.Itoa(vhost.Backend.Port)) newReq.RequestURI = "" if swa, ok := session.(samlsp.SessionWithAttributes); ok { @@ -385,6 +529,10 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes } // proxy the request to the backend + client, err := vhost.Backend.Client() + if err != nil { + sp.writeError(w, http.StatusInternalServerError, fmt.Errorf("error setting up connection to backend: %v", err)) + } response, err := client.Do(newReq) if err != nil { sp.writeError(w, http.StatusBadGateway, err) diff --git a/http/samlproxy/main.go b/http/samlproxy/main.go index d473e2c..52be074 100644 --- a/http/samlproxy/main.go +++ b/http/samlproxy/main.go @@ -3,36 +3,66 @@ package main import ( "context" "flag" + "os" "os/signal" "syscall" "time" + "gopkg.in/yaml.v3" + "go.fuhry.dev/runtime/http" "go.fuhry.dev/runtime/mtls" "go.fuhry.dev/runtime/utils/log" - "go.fuhry.dev/runtime/utils/stringmatch" ) func main() { mtls.SetDefaultIdentity("authproxy") - sp := http.SAMLProxy{} + sp := &http.SAMLProxy{ + Listener: http.SAMLListener{ + SAMLServiceProvider: &http.SAMLServiceProvider{}, + }, + } + vhost := &http.SAMLVirtualHost{ + Backend: &http.SAMLBackend{}, + } + + loadConfig := func(arg string) error { + contents, err := os.ReadFile(arg) + if err != nil { + return err + } + + err = yaml.Unmarshal(contents, sp) + return err + } + addRoute := func(arg string) error { + route, err := http.RouteFromArg(arg) + if err != nil { + return err + } + vhost.Routes = append(vhost.Routes, route) + return nil + } + + vhostName := flag.String("vhost", "", "HTTP(S) hostname to serve") + flag.Func("config", "YAML file to load configuration from", loadConfig) + flag.Func("route", "Route rule in the format of auth:field:matcher:value\n"+ + " auth: required, optional\n"+ + " field: path\n"+ + " matcher: prefix, suffix, exact, contains, regexp\n"+ + " value: any string", addRoute) flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider") flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata") flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private") - flag.StringVar(&sp.Backend.Host, "backend.host", "127.0.0.1", "backend host") - flag.IntVar(&sp.Backend.Port, "backend.port", 0, "backend port") - flag.StringVar(&sp.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend") + flag.StringVar(&vhost.Backend.Host, "backend.host", "127.0.0.1", "backend host") + flag.IntVar(&vhost.Backend.Port, "backend.port", 0, "backend port") + flag.StringVar(&vhost.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend") listen := flag.String("listen", "[::]:8443", "address for auth proxy to listen on") flag.Parse() - sp.Listener.Routes = []http.Route{ - { - Auth: http.AuthOptional, - Path: stringmatch.Exact("/trickortreat.html"), - }, - } + sp.Listener.VirtualHosts[*vhostName] = vhost ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel()