import (
"context"
+ "crypto"
"crypto/rsa"
+ "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
type authEnforcement uint
+type Route struct {
+ Auth authEnforcement
+ Path stringmatch.StringMatcher
+}
+
+type SAMLBackend struct {
+ Host string `yaml:"host"`
+ Port int `yaml:"port"`
+ Identity string `yaml:"mtls_id"`
+
+ client *http.Client
+ clientOnce sync.Once
+}
+
+type SAMLVirtualHost struct {
+ *SAMLServiceProvider `yaml:"saml"`
+
+ Backend *SAMLBackend `yaml:"backend"`
+ Routes []*Route `yaml:"routes"`
+}
+
+type SAMLServiceProvider struct {
+ EntityID string `yaml:"entity_id"`
+ EntityCertificate string `yaml:"entity_certificate"`
+ EntityPrivateKey string `yaml:"entity_key"`
+ IDP string `yaml:"idp"`
+
+ metadata *saml.EntityDescriptor
+ metadataOnce sync.Once
+
+ entityCert *x509.Certificate
+ entityKey *rsa.PrivateKey
+ certKeyOnce sync.Once
+}
+
+type SAMLListener struct {
+ *SAMLServiceProvider `yaml:"saml"`
+
+ Certificate string `yaml:"cert"`
+ VirtualHosts map[string]*SAMLVirtualHost `yaml:"virtual_hosts"`
+}
+
+type SAMLProxy struct {
+ Listener SAMLListener `yaml:"listener"`
+
+ logger log.Logger
+}
+
const (
AuthRequired authEnforcement = iota
AuthOptional
}
// UnmarshalYAML implements yaml.Unmarshaler
-func (ae *authEnforcement) UnmarshalYAML(node *yaml.Node) error {
- if node.Kind != yaml.ScalarNode {
- return errors.New("yaml field of type authEnforcement must be a string")
+func (r *Route) UnmarshalYAML(node *yaml.Node) error {
+ var rawNode struct {
+ Auth string `yaml:"auth"`
+ Path *stringmatch.MatchRule `yaml:"path"`
}
- switch node.Value {
+ if err := node.Decode(&rawNode); err != nil {
+ return err
+ }
+
+ switch rawNode.Auth {
case "required":
- *ae = AuthRequired
+ r.Auth = AuthRequired
case "optional":
- *ae = AuthOptional
+ r.Auth = AuthOptional
default:
- return fmt.Errorf("invalid auth enforcement string value: %s", node.Value)
+ return fmt.Errorf("error unmarshaling route: invalid auth enforcement string value: %s", node.Value)
+ }
+
+ if rawNode.Path != nil {
+ m, err := rawNode.Path.Matcher()
+ if err != nil {
+ return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err)
+ }
+ r.Path = m
+ } else {
+ return errors.New("error unmarshaling route: exactly one of (`path`) must be specified")
}
return nil
}
-type Route struct {
- Auth authEnforcement
- Path stringmatch.StringMatcher
-}
+// RouteFromArg implements the 3rd argument to flag.Func.
+//
+// It parses a string in the format of auth:field:match_mode:value, returning a Route if
+// it parses successfully.
+func RouteFromArg(arg string) (*Route, error) {
+ parts := strings.SplitN(arg, ":", 4)
+ if len(parts) != 4 {
+ return nil, fmt.Errorf("invalid route spec: %q", arg)
+ }
+ a, f, t, v := parts[0], parts[1], parts[2], parts[3]
+ var auth authEnforcement
+ switch strings.ToLower(a) {
+ case "r", "req", "required":
+ auth = AuthRequired
+ case "o", "opt", "optional":
+ auth = AuthOptional
+ default:
+ return nil, fmt.Errorf("invalid auth setting: %q", a)
+ }
-type SAMLListener struct {
- EntityID string `yaml:"entity_id"`
- EntityCertificate string `yaml:"entity_certificate"`
- EntityPrivateKey string `yaml:"entity_key"`
- IDP string `yaml:"idp"`
- Certificate string `yaml:"cert"`
- Routes []Route `yaml:"routes"`
-}
+ route := &Route{
+ Auth: auth,
+ }
-type SAMLBackend struct {
- Host string `yaml:"host"`
- Port int `yaml:"port"`
- Identity string `yaml:"mtls_id"`
-}
+ match := stringmatch.MatchRule{
+ Mode: t,
+ Value: v,
+ }
+ m, err := match.Matcher()
+ if err != nil {
+ return nil, err
+ }
-type SAMLProxy struct {
- Listener SAMLListener `yaml:"listener"`
- Backend SAMLBackend `yaml:"backend"`
+ switch strings.ToLower(f) {
+ case "p", "path":
+ route.Path = m
+ default:
+ return nil, fmt.Errorf("invalid match field: %q", f)
+ }
- logger log.Logger
- entityCert *x509.Certificate
- entityKey *rsa.PrivateKey
+ return route, nil
}
-func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
- if sp.logger == nil {
- sp.logger = log.Default().WithPrefix("SAMLProxy")
- }
+// Client returns an HTTP client for making requests to the backend.
+func (b *SAMLBackend) Client() (*http.Client, error) {
+ var err error
+ b.clientOnce.Do(func() {
+ transport := &http.Transport{}
+ var tlsConfig *tls.Config
+
+ if b.Identity != "" {
+ myIdentity := mtls.DefaultIdentity()
+ tlsConfig, err = myIdentity.TlsConfig(context.Background())
+ if err != nil {
+ return
+ }
+
+ verifier := mtls.NewPeerNameVerifier()
+ verifier.AllowFrom(mtls.Service, b.Identity)
+ err = verifier.ConfigureClient(tlsConfig)
+ if err != nil {
+ return
+ }
- idpMetadataUrl, err := url.Parse(sp.Listener.IDP)
+ transport.TLSClientConfig = tlsConfig
+ }
+
+ client := &http.Client{
+ Transport: transport,
+ }
+
+ b.client = client
+ })
if err != nil {
return nil, err
}
- idpMetadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataUrl)
- if err != nil {
- return nil, err
+ return b.client, nil
+}
+
+// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host
+// and other settings.
+func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
+ var _ yaml.Unmarshaler = &Route{}
+
+ if sp.logger == nil {
+ sp.logger = log.Default().WithPrefix("SAMLProxy")
}
- handler, err := sp.newHandler(ctx, idpMetadata)
+ handler, err := sp.newHandler()
if err != nil {
return nil, err
}
- if sp.Listener.EntityPrivateKey != "" {
- pvk, err := certutil.LoadPrivateKeyFromPEM(sp.Listener.EntityPrivateKey)
- if err != nil {
- return nil, err
- }
- rsaKey, ok := pvk.(*rsa.PrivateKey)
- if !ok {
- return nil, fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
- }
- sp.entityKey = rsaKey
- } else {
- // generate new RSA private key
- sp.entityKey, err = rsa.GenerateKey(saml.RandReader, 2048)
- if err != nil {
- return nil, err
- }
- }
- if sp.Listener.EntityCertificate != "" {
- certs, err := certutil.LoadCertificatesFromPEM(sp.Listener.EntityCertificate)
- if err != nil {
- return nil, err
- }
- sp.entityCert = certs[0]
- } else {
- // generate new self-signed X509 certificate
- serialBytes := make([]byte, 16)
- saml.RandReader.Read(serialBytes)
- serial := big.NewInt(0)
- serial.SetBytes(serialBytes)
-
- template := &x509.Certificate{
- Subject: pkix.Name{
- CommonName: sp.Listener.EntityID,
- },
- SerialNumber: serial,
- BasicConstraintsValid: true,
- ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
- KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
- IsCA: false,
- NotBefore: time.Now(),
- NotAfter: time.Now().Add(90 * 86400 * time.Second),
- }
- certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &sp.entityKey.PublicKey, sp.entityKey)
- if err != nil {
- return nil, err
- }
- cert, err := x509.ParseCertificate(certBytes)
- if err != nil {
- return nil, err
- }
- sp.entityCert = cert
- }
-
+ lm := log.NewLoggingMiddlewareWithLogger(handler, sp.logger)
server := &http.Server{
- Handler: handler,
+ Handler: lm.HandlerFunc(),
}
if sp.Listener.Certificate != "" {
if err != nil {
return nil, err
}
- // verif := mtls.NewPeerNameVerifier()
- // verif.AllowFrom(mtls.Service, "devicetrust")
- // verif.ConfigureServer(tlsConfig)
- // tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
- // tlsConfig.ClientCAs = x509.NewCertPool()
- // ca, err := certutil.LoadCertificatesFromPEM("/etc/ssl/mtls/rootca.pem")
- // if err == nil {
- // for _, cert := range ca {
- // tlsConfig.ClientCAs.AddCert(cert)
- // }
- // }
server.TLSConfig = tlsConfig
}
return server, nil
}
-func (sp *SAMLProxy) newHandler(ctx context.Context, idpMetadata *saml.EntityDescriptor) (http.HandlerFunc, error) {
- transport := &http.Transport{}
- samlSp := make(map[string]*samlsp.Middleware, 0)
- spMu := &sync.Mutex{}
- if sp.Backend.Identity != "" {
- myIdentity := mtls.DefaultIdentity()
- tlsConfig, err := myIdentity.TlsConfig(ctx)
+func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) {
+ var err error
+ sp.metadataOnce.Do(func() {
+ var idpMetadataUrl *url.URL
+ idpMetadataUrl, err = url.Parse(sp.IDP)
if err != nil {
- return nil, err
+ return
}
- verifier := mtls.NewPeerNameVerifier()
- verifier.AllowFrom(mtls.Service, sp.Backend.Identity)
- err = verifier.ConfigureClient(tlsConfig)
- if err != nil {
- return nil, err
+ sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl)
+ })
+
+ if err != nil {
+ sp.metadataOnce = sync.Once{}
+ return nil, err
+ }
+
+ return sp.metadata, nil
+}
+
+func (sp *SAMLServiceProvider) CertAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) {
+ sp.certKeyOnce.Do(func() {
+ if sp.EntityPrivateKey != "" {
+ var loadedKey crypto.PrivateKey
+ loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey)
+ if err != nil {
+ return
+ }
+ var ok bool
+ pvk, ok = loadedKey.(*rsa.PrivateKey)
+ if !ok {
+ err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
+ return
+ }
+ } else {
+ // generate new RSA private key
+ pvk, err = rsa.GenerateKey(saml.RandReader, 2048)
+ if err != nil {
+ return
+ }
+ }
+ if sp.EntityCertificate != "" {
+ certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate)
+ if err != nil {
+ return
+ }
+ cert = certs[0]
+ } else {
+ // generate new self-signed X509 certificate
+ serialBytes := make([]byte, 16)
+ saml.RandReader.Read(serialBytes)
+ serial := big.NewInt(0)
+ serial.SetBytes(serialBytes)
+
+ template := &x509.Certificate{
+ Subject: pkix.Name{
+ CommonName: sp.EntityID,
+ },
+ SerialNumber: serial,
+ BasicConstraintsValid: true,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
+ IsCA: false,
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(90 * 86400 * time.Second),
+ }
+ certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk)
+ if err != nil {
+ return
+ }
+ cert, err = x509.ParseCertificate(certBytes)
+ if err != nil {
+ return
+ }
}
- transport.TLSClientConfig = tlsConfig
+ sp.entityCert = cert
+ sp.entityKey = pvk
+ })
+
+ if err != nil {
+ sp.metadataOnce = sync.Once{}
+ return nil, nil, err
}
- client := &http.Client{
- Transport: transport,
+
+ return sp.entityCert, sp.entityKey, nil
+}
+
+func (sp *SAMLServiceProvider) NewServiceProvider(host string) (*samlsp.Middleware, error) {
+ idpMetadata, err := sp.Metadata()
+ if err != nil {
+ return nil, err
+ }
+ cert, key, err := sp.CertAndKey()
+ if err != nil {
+ return nil, err
}
+ return samlsp.New(samlsp.Options{
+ EntityID: sp.EntityID,
+ URL: url.URL{
+ Scheme: "https",
+ Host: host,
+ },
+ Key: key,
+ Certificate: cert,
+ IDPMetadata: idpMetadata,
+ })
+}
+
+func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
+ samlSp := make(map[string]*samlsp.Middleware, 0)
+ spMu := &sync.Mutex{}
handle := func(w http.ResponseWriter, r *http.Request) {
// ensure host header present
return
}
- // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
- // was reused because our certificate is also valid for the SSO URL.
- for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
- for _, ssoSvc := range ssoDesc.SingleSignOnServices {
- if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
- if loginUrl.Host == host {
- sp.writeError(w, http.StatusMisdirectedRequest,
- errors.New("Misdirected request: this is not the IDP you're looking for"))
+ // make sure this host is known
+ vhost, ok := sp.Listener.VirtualHosts[host]
+ if !ok {
+ sp.writeError(w, http.StatusMisdirectedRequest,
+ errors.New("Misdirected request: unknown virtual host"))
- return
- }
- }
- }
+ return
}
+
// ensure we have SP instance
spMu.Lock()
if _, ok := samlSp[host]; !ok {
- provider, err := samlsp.New(samlsp.Options{
- EntityID: sp.Listener.EntityID,
- URL: url.URL{
- Scheme: "https",
- Host: host,
- },
- Key: sp.entityKey,
- Certificate: sp.entityCert,
- IDPMetadata: idpMetadata,
- })
+ samlSettings := vhost.SAMLServiceProvider
+ if samlSettings == nil {
+ samlSettings = sp.Listener.SAMLServiceProvider
+ }
+ idpMetadata, err := samlSettings.Metadata()
if err != nil {
sp.writeError(w, http.StatusInternalServerError, err)
return
}
-
- samlSp[host] = provider
+ // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
+ // was reused because our certificate is also valid for the SSO URL.
+ for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
+ for _, ssoSvc := range ssoDesc.SingleSignOnServices {
+ if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
+ if loginUrl.Host == host {
+ sp.writeError(w, http.StatusMisdirectedRequest,
+ errors.New("Misdirected request: this is not the IDP you're looking for"))
+
+ return
+ }
+ }
+ }
+ }
+ middleware, err := samlSettings.NewServiceProvider(host)
+ if err != nil {
+ sp.writeError(w, http.StatusInternalServerError, err)
+ return
+ }
+ samlSp[host] = middleware
}
spMu.Unlock()
defaultRoute := true
sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
- for _, route := range sp.Listener.Routes {
+ for _, route := range vhost.Routes {
match := false
if route.Path != nil {
match = route.Path.Match(r.URL.Path)
newReq := r.Clone(r.Context())
newReq.URL.Scheme = "http"
- if sp.Backend.Identity != "" {
+ if vhost.Backend.Identity != "" {
newReq.URL.Scheme = "https"
}
- newReq.URL.Host = net.JoinHostPort(sp.Backend.Host, strconv.Itoa(sp.Backend.Port))
+ newReq.URL.Host = net.JoinHostPort(vhost.Backend.Host, strconv.Itoa(vhost.Backend.Port))
newReq.RequestURI = ""
if swa, ok := session.(samlsp.SessionWithAttributes); ok {
}
// proxy the request to the backend
+ client, err := vhost.Backend.Client()
+ if err != nil {
+ sp.writeError(w, http.StatusInternalServerError, fmt.Errorf("error setting up connection to backend: %v", err))
+ }
response, err := client.Do(newReq)
if err != nil {
sp.writeError(w, http.StatusBadGateway, err)
import (
"context"
"flag"
+ "os"
"os/signal"
"syscall"
"time"
+ "gopkg.in/yaml.v3"
+
"go.fuhry.dev/runtime/http"
"go.fuhry.dev/runtime/mtls"
"go.fuhry.dev/runtime/utils/log"
- "go.fuhry.dev/runtime/utils/stringmatch"
)
func main() {
mtls.SetDefaultIdentity("authproxy")
- sp := http.SAMLProxy{}
+ sp := &http.SAMLProxy{
+ Listener: http.SAMLListener{
+ SAMLServiceProvider: &http.SAMLServiceProvider{},
+ },
+ }
+ vhost := &http.SAMLVirtualHost{
+ Backend: &http.SAMLBackend{},
+ }
+
+ loadConfig := func(arg string) error {
+ contents, err := os.ReadFile(arg)
+ if err != nil {
+ return err
+ }
+
+ err = yaml.Unmarshal(contents, sp)
+ return err
+ }
+ addRoute := func(arg string) error {
+ route, err := http.RouteFromArg(arg)
+ if err != nil {
+ return err
+ }
+ vhost.Routes = append(vhost.Routes, route)
+ return nil
+ }
+
+ vhostName := flag.String("vhost", "", "HTTP(S) hostname to serve")
+ flag.Func("config", "YAML file to load configuration from", loadConfig)
+ flag.Func("route", "Route rule in the format of auth:field:matcher:value\n"+
+ " auth: required, optional\n"+
+ " field: path\n"+
+ " matcher: prefix, suffix, exact, contains, regexp\n"+
+ " value: any string", addRoute)
flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider")
flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata")
flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private")
- flag.StringVar(&sp.Backend.Host, "backend.host", "127.0.0.1", "backend host")
- flag.IntVar(&sp.Backend.Port, "backend.port", 0, "backend port")
- flag.StringVar(&sp.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend")
+ flag.StringVar(&vhost.Backend.Host, "backend.host", "127.0.0.1", "backend host")
+ flag.IntVar(&vhost.Backend.Port, "backend.port", 0, "backend port")
+ flag.StringVar(&vhost.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend")
listen := flag.String("listen", "[::]:8443", "address for auth proxy to listen on")
flag.Parse()
- sp.Listener.Routes = []http.Route{
- {
- Auth: http.AuthOptional,
- Path: stringmatch.Exact("/trickortreat.html"),
- },
- }
+ sp.Listener.VirtualHosts[*vhostName] = vhost
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()