]> go.fuhry.dev Git - runtime.git/commitdiff
[http/samlproxy] multiple vhosts, loadable yaml config, SAML config per vhost
authorDan Fuhry <dan@fuhry.com>
Sun, 23 Mar 2025 01:30:11 +0000 (21:30 -0400)
committerDan Fuhry <dan@fuhry.com>
Sun, 23 Mar 2025 01:30:11 +0000 (21:30 -0400)
http/samlproxy.go
http/samlproxy/main.go

index 178dd5fc42c2a933ed535c32e285841fd2f13d6d..c39e6d3ec5adf97598c96423cc9cac2417cfad70 100644 (file)
@@ -2,7 +2,9 @@ package http
 
 import (
        "context"
+       "crypto"
        "crypto/rsa"
+       "crypto/tls"
        "crypto/x509"
        "crypto/x509/pkix"
        "encoding/json"
@@ -32,6 +34,54 @@ import (
 
 type authEnforcement uint
 
+type Route struct {
+       Auth authEnforcement
+       Path stringmatch.StringMatcher
+}
+
+type SAMLBackend struct {
+       Host     string `yaml:"host"`
+       Port     int    `yaml:"port"`
+       Identity string `yaml:"mtls_id"`
+
+       client     *http.Client
+       clientOnce sync.Once
+}
+
+type SAMLVirtualHost struct {
+       *SAMLServiceProvider `yaml:"saml"`
+
+       Backend *SAMLBackend `yaml:"backend"`
+       Routes  []*Route     `yaml:"routes"`
+}
+
+type SAMLServiceProvider struct {
+       EntityID          string `yaml:"entity_id"`
+       EntityCertificate string `yaml:"entity_certificate"`
+       EntityPrivateKey  string `yaml:"entity_key"`
+       IDP               string `yaml:"idp"`
+
+       metadata     *saml.EntityDescriptor
+       metadataOnce sync.Once
+
+       entityCert  *x509.Certificate
+       entityKey   *rsa.PrivateKey
+       certKeyOnce sync.Once
+}
+
+type SAMLListener struct {
+       *SAMLServiceProvider `yaml:"saml"`
+
+       Certificate  string                      `yaml:"cert"`
+       VirtualHosts map[string]*SAMLVirtualHost `yaml:"virtual_hosts"`
+}
+
+type SAMLProxy struct {
+       Listener SAMLListener `yaml:"listener"`
+
+       logger log.Logger
+}
+
 const (
        AuthRequired authEnforcement = iota
        AuthOptional
@@ -59,126 +109,134 @@ func (ae *authEnforcement) UnmarshalJSON(token []byte) error {
 }
 
 // UnmarshalYAML implements yaml.Unmarshaler
-func (ae *authEnforcement) UnmarshalYAML(node *yaml.Node) error {
-       if node.Kind != yaml.ScalarNode {
-               return errors.New("yaml field of type authEnforcement must be a string")
+func (r *Route) UnmarshalYAML(node *yaml.Node) error {
+       var rawNode struct {
+               Auth string                 `yaml:"auth"`
+               Path *stringmatch.MatchRule `yaml:"path"`
        }
 
-       switch node.Value {
+       if err := node.Decode(&rawNode); err != nil {
+               return err
+       }
+
+       switch rawNode.Auth {
        case "required":
-               *ae = AuthRequired
+               r.Auth = AuthRequired
        case "optional":
-               *ae = AuthOptional
+               r.Auth = AuthOptional
        default:
-               return fmt.Errorf("invalid auth enforcement string value: %s", node.Value)
+               return fmt.Errorf("error unmarshaling route: invalid auth enforcement string value: %s", node.Value)
+       }
+
+       if rawNode.Path != nil {
+               m, err := rawNode.Path.Matcher()
+               if err != nil {
+                       return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err)
+               }
+               r.Path = m
+       } else {
+               return errors.New("error unmarshaling route: exactly one of (`path`) must be specified")
        }
 
        return nil
 }
 
-type Route struct {
-       Auth authEnforcement
-       Path stringmatch.StringMatcher
-}
+// RouteFromArg implements the 3rd argument to flag.Func.
+//
+// It parses a string in the format of auth:field:match_mode:value, returning a Route if
+// it parses successfully.
+func RouteFromArg(arg string) (*Route, error) {
+       parts := strings.SplitN(arg, ":", 4)
+       if len(parts) != 4 {
+               return nil, fmt.Errorf("invalid route spec: %q", arg)
+       }
+       a, f, t, v := parts[0], parts[1], parts[2], parts[3]
+       var auth authEnforcement
+       switch strings.ToLower(a) {
+       case "r", "req", "required":
+               auth = AuthRequired
+       case "o", "opt", "optional":
+               auth = AuthOptional
+       default:
+               return nil, fmt.Errorf("invalid auth setting: %q", a)
+       }
 
-type SAMLListener struct {
-       EntityID          string  `yaml:"entity_id"`
-       EntityCertificate string  `yaml:"entity_certificate"`
-       EntityPrivateKey  string  `yaml:"entity_key"`
-       IDP               string  `yaml:"idp"`
-       Certificate       string  `yaml:"cert"`
-       Routes            []Route `yaml:"routes"`
-}
+       route := &Route{
+               Auth: auth,
+       }
 
-type SAMLBackend struct {
-       Host     string `yaml:"host"`
-       Port     int    `yaml:"port"`
-       Identity string `yaml:"mtls_id"`
-}
+       match := stringmatch.MatchRule{
+               Mode:  t,
+               Value: v,
+       }
+       m, err := match.Matcher()
+       if err != nil {
+               return nil, err
+       }
 
-type SAMLProxy struct {
-       Listener SAMLListener `yaml:"listener"`
-       Backend  SAMLBackend  `yaml:"backend"`
+       switch strings.ToLower(f) {
+       case "p", "path":
+               route.Path = m
+       default:
+               return nil, fmt.Errorf("invalid match field: %q", f)
+       }
 
-       logger     log.Logger
-       entityCert *x509.Certificate
-       entityKey  *rsa.PrivateKey
+       return route, nil
 }
 
-func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
-       if sp.logger == nil {
-               sp.logger = log.Default().WithPrefix("SAMLProxy")
-       }
+// Client returns an HTTP client for making requests to the backend.
+func (b *SAMLBackend) Client() (*http.Client, error) {
+       var err error
+       b.clientOnce.Do(func() {
+               transport := &http.Transport{}
+               var tlsConfig *tls.Config
+
+               if b.Identity != "" {
+                       myIdentity := mtls.DefaultIdentity()
+                       tlsConfig, err = myIdentity.TlsConfig(context.Background())
+                       if err != nil {
+                               return
+                       }
+
+                       verifier := mtls.NewPeerNameVerifier()
+                       verifier.AllowFrom(mtls.Service, b.Identity)
+                       err = verifier.ConfigureClient(tlsConfig)
+                       if err != nil {
+                               return
+                       }
 
-       idpMetadataUrl, err := url.Parse(sp.Listener.IDP)
+                       transport.TLSClientConfig = tlsConfig
+               }
+
+               client := &http.Client{
+                       Transport: transport,
+               }
+
+               b.client = client
+       })
        if err != nil {
                return nil, err
        }
-       idpMetadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataUrl)
-       if err != nil {
-               return nil, err
+       return b.client, nil
+}
+
+// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host
+// and other settings.
+func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
+       var _ yaml.Unmarshaler = &Route{}
+
+       if sp.logger == nil {
+               sp.logger = log.Default().WithPrefix("SAMLProxy")
        }
 
-       handler, err := sp.newHandler(ctx, idpMetadata)
+       handler, err := sp.newHandler()
        if err != nil {
                return nil, err
        }
 
-       if sp.Listener.EntityPrivateKey != "" {
-               pvk, err := certutil.LoadPrivateKeyFromPEM(sp.Listener.EntityPrivateKey)
-               if err != nil {
-                       return nil, err
-               }
-               rsaKey, ok := pvk.(*rsa.PrivateKey)
-               if !ok {
-                       return nil, fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
-               }
-               sp.entityKey = rsaKey
-       } else {
-               // generate new RSA private key
-               sp.entityKey, err = rsa.GenerateKey(saml.RandReader, 2048)
-               if err != nil {
-                       return nil, err
-               }
-       }
-       if sp.Listener.EntityCertificate != "" {
-               certs, err := certutil.LoadCertificatesFromPEM(sp.Listener.EntityCertificate)
-               if err != nil {
-                       return nil, err
-               }
-               sp.entityCert = certs[0]
-       } else {
-               // generate new self-signed X509 certificate
-               serialBytes := make([]byte, 16)
-               saml.RandReader.Read(serialBytes)
-               serial := big.NewInt(0)
-               serial.SetBytes(serialBytes)
-
-               template := &x509.Certificate{
-                       Subject: pkix.Name{
-                               CommonName: sp.Listener.EntityID,
-                       },
-                       SerialNumber:          serial,
-                       BasicConstraintsValid: true,
-                       ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
-                       KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
-                       IsCA:                  false,
-                       NotBefore:             time.Now(),
-                       NotAfter:              time.Now().Add(90 * 86400 * time.Second),
-               }
-               certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &sp.entityKey.PublicKey, sp.entityKey)
-               if err != nil {
-                       return nil, err
-               }
-               cert, err := x509.ParseCertificate(certBytes)
-               if err != nil {
-                       return nil, err
-               }
-               sp.entityCert = cert
-       }
-
+       lm := log.NewLoggingMiddlewareWithLogger(handler, sp.logger)
        server := &http.Server{
-               Handler: handler,
+               Handler: lm.HandlerFunc(),
        }
 
        if sp.Listener.Certificate != "" {
@@ -187,46 +245,124 @@ func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server
                if err != nil {
                        return nil, err
                }
-               // verif := mtls.NewPeerNameVerifier()
-               // verif.AllowFrom(mtls.Service, "devicetrust")
-               // verif.ConfigureServer(tlsConfig)
-               // tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
-               // tlsConfig.ClientCAs = x509.NewCertPool()
-               // ca, err := certutil.LoadCertificatesFromPEM("/etc/ssl/mtls/rootca.pem")
-               // if err == nil {
-               //      for _, cert := range ca {
-               //              tlsConfig.ClientCAs.AddCert(cert)
-               //      }
-               // }
                server.TLSConfig = tlsConfig
        }
 
        return server, nil
 }
 
-func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDescriptor) (http.HandlerFunc, error) {
-       transport := &http.Transport{}
-       samlSp := make(map[string]*samlsp.Middleware, 0)
-       spMu := &sync.Mutex{}
-       if sp.Backend.Identity != "" {
-               myIdentity := mtls.DefaultIdentity()
-               tlsConfig, err := myIdentity.TlsConfig(ctx)
+func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) {
+       var err error
+       sp.metadataOnce.Do(func() {
+               var idpMetadataUrl *url.URL
+               idpMetadataUrl, err = url.Parse(sp.IDP)
                if err != nil {
-                       return nil, err
+                       return
                }
 
-               verifier := mtls.NewPeerNameVerifier()
-               verifier.AllowFrom(mtls.Service, sp.Backend.Identity)
-               err = verifier.ConfigureClient(tlsConfig)
-               if err != nil {
-                       return nil, err
+               sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl)
+       })
+
+       if err != nil {
+               sp.metadataOnce = sync.Once{}
+               return nil, err
+       }
+
+       return sp.metadata, nil
+}
+
+func (sp *SAMLServiceProvider) CertAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) {
+       sp.certKeyOnce.Do(func() {
+               if sp.EntityPrivateKey != "" {
+                       var loadedKey crypto.PrivateKey
+                       loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey)
+                       if err != nil {
+                               return
+                       }
+                       var ok bool
+                       pvk, ok = loadedKey.(*rsa.PrivateKey)
+                       if !ok {
+                               err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
+                               return
+                       }
+               } else {
+                       // generate new RSA private key
+                       pvk, err = rsa.GenerateKey(saml.RandReader, 2048)
+                       if err != nil {
+                               return
+                       }
+               }
+               if sp.EntityCertificate != "" {
+                       certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate)
+                       if err != nil {
+                               return
+                       }
+                       cert = certs[0]
+               } else {
+                       // generate new self-signed X509 certificate
+                       serialBytes := make([]byte, 16)
+                       saml.RandReader.Read(serialBytes)
+                       serial := big.NewInt(0)
+                       serial.SetBytes(serialBytes)
+
+                       template := &x509.Certificate{
+                               Subject: pkix.Name{
+                                       CommonName: sp.EntityID,
+                               },
+                               SerialNumber:          serial,
+                               BasicConstraintsValid: true,
+                               ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+                               KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
+                               IsCA:                  false,
+                               NotBefore:             time.Now(),
+                               NotAfter:              time.Now().Add(90 * 86400 * time.Second),
+                       }
+                       certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk)
+                       if err != nil {
+                               return
+                       }
+                       cert, err = x509.ParseCertificate(certBytes)
+                       if err != nil {
+                               return
+                       }
                }
 
-               transport.TLSClientConfig = tlsConfig
+               sp.entityCert = cert
+               sp.entityKey = pvk
+       })
+
+       if err != nil {
+               sp.metadataOnce = sync.Once{}
+               return nil, nil, err
        }
-       client := &http.Client{
-               Transport: transport,
+
+       return sp.entityCert, sp.entityKey, nil
+}
+
+func (sp *SAMLServiceProvider) NewServiceProvider(host string) (*samlsp.Middleware, error) {
+       idpMetadata, err := sp.Metadata()
+       if err != nil {
+               return nil, err
+       }
+       cert, key, err := sp.CertAndKey()
+       if err != nil {
+               return nil, err
        }
+       return samlsp.New(samlsp.Options{
+               EntityID: sp.EntityID,
+               URL: url.URL{
+                       Scheme: "https",
+                       Host:   host,
+               },
+               Key:         key,
+               Certificate: cert,
+               IDPMetadata: idpMetadata,
+       })
+}
+
+func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
+       samlSp := make(map[string]*samlsp.Middleware, 0)
+       spMu := &sync.Mutex{}
 
        handle := func(w http.ResponseWriter, r *http.Request) {
                // ensure host header present
@@ -250,39 +386,47 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes
                        return
                }
 
-               // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
-               // was reused because our certificate is also valid for the SSO URL.
-               for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
-                       for _, ssoSvc := range ssoDesc.SingleSignOnServices {
-                               if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
-                                       if loginUrl.Host == host {
-                                               sp.writeError(w, http.StatusMisdirectedRequest,
-                                                       errors.New("Misdirected request: this is not the IDP you're looking for"))
+               // make sure this host is known
+               vhost, ok := sp.Listener.VirtualHosts[host]
+               if !ok {
+                       sp.writeError(w, http.StatusMisdirectedRequest,
+                               errors.New("Misdirected request: unknown virtual host"))
 
-                                               return
-                                       }
-                               }
-                       }
+                       return
                }
+
                // ensure we have SP instance
                spMu.Lock()
                if _, ok := samlSp[host]; !ok {
-                       provider, err := samlsp.New(samlsp.Options{
-                               EntityID: sp.Listener.EntityID,
-                               URL: url.URL{
-                                       Scheme: "https",
-                                       Host:   host,
-                               },
-                               Key:         sp.entityKey,
-                               Certificate: sp.entityCert,
-                               IDPMetadata: idpMetadata,
-                       })
+                       samlSettings := vhost.SAMLServiceProvider
+                       if samlSettings == nil {
+                               samlSettings = sp.Listener.SAMLServiceProvider
+                       }
+                       idpMetadata, err := samlSettings.Metadata()
                        if err != nil {
                                sp.writeError(w, http.StatusInternalServerError, err)
                                return
                        }
-
-                       samlSp[host] = provider
+                       // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
+                       // was reused because our certificate is also valid for the SSO URL.
+                       for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
+                               for _, ssoSvc := range ssoDesc.SingleSignOnServices {
+                                       if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
+                                               if loginUrl.Host == host {
+                                                       sp.writeError(w, http.StatusMisdirectedRequest,
+                                                               errors.New("Misdirected request: this is not the IDP you're looking for"))
+
+                                                       return
+                                               }
+                                       }
+                               }
+                       }
+                       middleware, err := samlSettings.NewServiceProvider(host)
+                       if err != nil {
+                               sp.writeError(w, http.StatusInternalServerError, err)
+                               return
+                       }
+                       samlSp[host] = middleware
                }
                spMu.Unlock()
 
@@ -302,7 +446,7 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes
 
                defaultRoute := true
                sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
-               for _, route := range sp.Listener.Routes {
+               for _, route := range vhost.Routes {
                        match := false
                        if route.Path != nil {
                                match = route.Path.Match(r.URL.Path)
@@ -341,10 +485,10 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes
 
                newReq := r.Clone(r.Context())
                newReq.URL.Scheme = "http"
-               if sp.Backend.Identity != "" {
+               if vhost.Backend.Identity != "" {
                        newReq.URL.Scheme = "https"
                }
-               newReq.URL.Host = net.JoinHostPort(sp.Backend.Host, strconv.Itoa(sp.Backend.Port))
+               newReq.URL.Host = net.JoinHostPort(vhost.Backend.Host, strconv.Itoa(vhost.Backend.Port))
                newReq.RequestURI = ""
 
                if swa, ok := session.(samlsp.SessionWithAttributes); ok {
@@ -385,6 +529,10 @@ func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDes
                }
 
                // proxy the request to the backend
+               client, err := vhost.Backend.Client()
+               if err != nil {
+                       sp.writeError(w, http.StatusInternalServerError, fmt.Errorf("error setting up connection to backend: %v", err))
+               }
                response, err := client.Do(newReq)
                if err != nil {
                        sp.writeError(w, http.StatusBadGateway, err)
index d473e2c7439126ba589d2491824b76248c9722e0..52be074327683c0f8f322c5faf77a58ac1938748 100644 (file)
@@ -3,36 +3,66 @@ package main
 import (
        "context"
        "flag"
+       "os"
        "os/signal"
        "syscall"
        "time"
 
+       "gopkg.in/yaml.v3"
+
        "go.fuhry.dev/runtime/http"
        "go.fuhry.dev/runtime/mtls"
        "go.fuhry.dev/runtime/utils/log"
-       "go.fuhry.dev/runtime/utils/stringmatch"
 )
 
 func main() {
        mtls.SetDefaultIdentity("authproxy")
 
-       sp := http.SAMLProxy{}
+       sp := &http.SAMLProxy{
+               Listener: http.SAMLListener{
+                       SAMLServiceProvider: &http.SAMLServiceProvider{},
+               },
+       }
+       vhost := &http.SAMLVirtualHost{
+               Backend: &http.SAMLBackend{},
+       }
+
+       loadConfig := func(arg string) error {
+               contents, err := os.ReadFile(arg)
+               if err != nil {
+                       return err
+               }
+
+               err = yaml.Unmarshal(contents, sp)
+               return err
+       }
+       addRoute := func(arg string) error {
+               route, err := http.RouteFromArg(arg)
+               if err != nil {
+                       return err
+               }
+               vhost.Routes = append(vhost.Routes, route)
+               return nil
+       }
+
+       vhostName := flag.String("vhost", "", "HTTP(S) hostname to serve")
+       flag.Func("config", "YAML file to load configuration from", loadConfig)
+       flag.Func("route", "Route rule in the format of auth:field:matcher:value\n"+
+               "  auth: required, optional\n"+
+               "  field: path\n"+
+               "  matcher: prefix, suffix, exact, contains, regexp\n"+
+               "  value: any string", addRoute)
        flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider")
        flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata")
        flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private")
-       flag.StringVar(&sp.Backend.Host, "backend.host", "127.0.0.1", "backend host")
-       flag.IntVar(&sp.Backend.Port, "backend.port", 0, "backend port")
-       flag.StringVar(&sp.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend")
+       flag.StringVar(&vhost.Backend.Host, "backend.host", "127.0.0.1", "backend host")
+       flag.IntVar(&vhost.Backend.Port, "backend.port", 0, "backend port")
+       flag.StringVar(&vhost.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend")
        listen := flag.String("listen", "[::]:8443", "address for auth proxy to listen on")
 
        flag.Parse()
 
-       sp.Listener.Routes = []http.Route{
-               {
-                       Auth: http.AuthOptional,
-                       Path: stringmatch.Exact("/trickortreat.html"),
-               },
-       }
+       sp.Listener.VirtualHosts[*vhostName] = vhost
 
        ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
        defer cancel()