From 61929de458bb3568312c0c95bad1126d9c5013b4 Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Sat, 14 Mar 2026 20:04:24 -0400 Subject: [PATCH] [mtls] support configs for different environments, embed default configs - Only try the identity drivers and paths listed in the config - Use different default config based on environment heuristics - Remove all hardcoded file paths, make file provider paths fully configurable Supporting changes: - Make IdentityLoaderFunc public - Rename type: `IdentityClass` -> `VerifyMatchBy` --- cmd/echo_client/main.go | 6 +- mtls/BUILD.bazel | 8 ++ mtls/config.go | 214 ++++++++++++++++++++++++++++++++ mtls/configs/prod.yaml | 18 +++ mtls/configs/user.yaml | 13 ++ mtls/identity.go | 37 ++++-- mtls/lazy_identity.go | 33 +---- mtls/provider_file.go | 124 +++++++++++++++++- mtls/provider_keychain_macos.go | 9 ++ mtls/provider_tpm2_pkcs11.go | 23 +++- mtls/verify_names.go | 26 ++-- mtls/verify_roots.go | 33 ++--- 12 files changed, 477 insertions(+), 67 deletions(-) create mode 100644 mtls/config.go create mode 100644 mtls/configs/prod.yaml create mode 100644 mtls/configs/user.yaml diff --git a/cmd/echo_client/main.go b/cmd/echo_client/main.go index 8b1f98c..37496a9 100644 --- a/cmd/echo_client/main.go +++ b/cmd/echo_client/main.go @@ -18,7 +18,11 @@ func main() { flag.Parse() logger := log.Default().WithPrefix("EchoClient") - clientId := mtls.DefaultIdentity() + mtlsConfig := mtls.DefaultConfig() + clientId, err := mtlsConfig.DefaultIdentity() + if err != nil { + logger.Panic(err) + } client, err := grpc.NewGrpcClient(ctx, "echo", clientId) if err != nil { logger.Panic(err) diff --git a/mtls/BUILD.bazel b/mtls/BUILD.bazel index 08351c5..9fa9e32 100644 --- a/mtls/BUILD.bazel +++ b/mtls/BUILD.bazel @@ -3,6 +3,7 @@ load("@rules_go//go:def.bzl", "go_library") go_library( name = "mtls", srcs = [ + "config.go", "identity.go", "lazy_identity.go", "pkcs11.go", @@ -19,6 +20,10 @@ go_library( "@rules_go//go/platform:linux_amd64": True, "//conditions:default": False, }), + embedsrcs = [ + "configs/prod.yaml", + "configs/user.yaml", + ], importpath = "go.fuhry.dev/runtime/mtls", visibility = ["//visibility:public"], deps = [ @@ -27,8 +32,11 @@ go_library( "//mtls/fsnotify", "//utils/fsutil", "//utils/hashset", + "//utils/hostname", "//utils/log", + "//utils/subst", "@com_github_thalesignite_crypto11//:crypto11", + "@in_gopkg_yaml_v3//:yaml_v3", ] + select({ "@rules_go//go/platform:darwin": [ "//utils/stringmatch", diff --git a/mtls/config.go b/mtls/config.go new file mode 100644 index 0000000..bf42052 --- /dev/null +++ b/mtls/config.go @@ -0,0 +1,214 @@ +package mtls + +import ( + "embed" + "flag" + "fmt" + "os" + "strings" + "sync" + + "gopkg.in/yaml.v3" + + "go.fuhry.dev/runtime/constants" + "go.fuhry.dev/runtime/utils/fsutil" + "go.fuhry.dev/runtime/utils/hostname" +) + +type Config interface { + NewIdentity(cls PrincipalClass, princ string) (Identity, error) + DefaultIdentity() (Identity, error) +} + +type providerFactory interface { + New(node *yaml.Node) (IdentityLoaderFunc, error) +} + +type providerNode struct { + Loader IdentityLoaderFunc +} + +type configCertificates struct { + Providers []*providerNode `yaml:"providers"` +} + +type cacheKey struct { + cls PrincipalClass + princ string +} + +type config struct { + Certificates *configCertificates `yaml:"certificates"` + + cache map[cacheKey]Identity + cacheMu sync.Mutex +} + +const ( + embedPrefix = "embed:" +) + +var providerFactories = make(map[string]providerFactory) + +//go:embed configs +var defaultConfigs embed.FS + +var ( + defaultConfig *config + defaultConfigOnce sync.Once + defaultConfigPath = "" + defaultEmbeddedConfigPath = "embed:configs/prod.yaml" + defaultFilesystemConfigPath = "/etc/" + constants.OrgSlug + "/mtls.yaml" +) + +func (c *config) NewIdentity(cls PrincipalClass, princ string) (Identity, error) { + if i, ok := c.cacheGet(cls, princ); ok { + return i, nil + } + if cls == AnonymousPrincipal { + id := &anonymousIdentity{} + c.cacheStore(cls, princ, id) + return id, nil + } + for _, p := range c.Certificates.Providers { + if cp, err := p.Loader(cls, princ); err == nil { + id := &substantiatedIdentity{cp} + + c.cacheStore(cls, princ, id) + return id, nil + } + } + + return nil, fmt.Errorf( + "no provider was able to satisfy the request for identity with class %s, principal %s", + cls.String(), princ) +} + +func (c *config) cacheGet(cls PrincipalClass, princ string) (Identity, bool) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + if c.cache == nil { + c.cache = make(map[cacheKey]Identity) + } + k := cacheKey{cls, princ} + + if entry, ok := c.cache[k]; ok { + if identityIsValid(entry) { + return entry, true + } + } + + return nil, false +} + +func (c *config) cacheStore(cls PrincipalClass, princ string, id Identity) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + if c.cache == nil { + c.cache = make(map[cacheKey]Identity) + } + k := cacheKey{cls, princ} + + c.cache[k] = id +} + +func (c *config) DefaultIdentity() (Identity, error) { + id := ParseIdentity(defaultMtlsIdentity) + return c.NewIdentity(id.Class(), id.Name()) +} + +func RegisterProviderFactory(name string, fac providerFactory) { + if _, ok := providerFactories[name]; ok { + panic(fmt.Sprintf("provider factory %q already registered", name)) + } + + providerFactories[name] = fac +} + +func (p *providerNode) UnmarshalYAML(node *yaml.Node) error { + cfg := new(struct { + Engine string `yaml:"engine"` + IgnoreFailures bool `yaml:"ignore_failures"` + }) + + if err := node.Decode(cfg); err != nil { + return err + } + + fac, ok := providerFactories[cfg.Engine] + if !ok { + if cfg.IgnoreFailures { + p.Loader = nullIdentityLoader + return nil + } + return fmt.Errorf("unknown provider engine: %q", cfg.Engine) + } + + loader, err := fac.New(node) + if err != nil { + if cfg.IgnoreFailures { + p.Loader = nullIdentityLoader + return nil + } + return err + } + + p.Loader = loader + return nil +} + +func nullIdentityLoader(cls PrincipalClass, princ string) (CertificateProvider, error) { + return nil, ErrUnsupportedClass +} + +func initConfig() { + defaultConfigPath = defaultEmbeddedConfigPath + if hostname.IsLikelyUserMachine() { + defaultConfigPath = "embed:configs/user.yaml" + } + if err := fsutil.FileExistsAndIsReadable(defaultFilesystemConfigPath); err == nil { + defaultConfigPath = defaultFilesystemConfigPath + } + + flag.StringVar(&defaultConfigPath, "mtls.config", defaultConfigPath, "path to the default mtls configuration file") +} + +func DefaultConfig() Config { + if !flag.Parsed() { + panic("mtls: DefaultConfig() called before flags were parsed") + } + + defaultConfigOnce.Do(func() { + var contents []byte + + logger.Noticef("using defaultConfigPath = %q", defaultConfigPath) + if defaultConfigPath == "" { + } + if strings.HasPrefix(defaultConfigPath, embedPrefix) { + configPath := defaultConfigPath[len(embedPrefix):] + if c, err := defaultConfigs.ReadFile(configPath); err != nil { + logger.Panicf("failed to read embedded config %q: %v", configPath, err) + } else { + contents = c + } + } else { + if c, err := os.ReadFile(defaultConfigPath); err != nil { + logger.Panicf("failed to read mtls config %q from filesystem: %v", + defaultConfigPath, err) + } else { + contents = c + } + } + + defaultConfig = &config{} + if err := yaml.Unmarshal(contents, defaultConfig); err != nil { + logger.Panicf("error loading mtls config %q: %v", defaultConfigPath, err) + } + + logger.Noticef("loaded %d providers from config %q", len(defaultConfig.Certificates.Providers), defaultConfigPath) + }) + + return defaultConfig +} diff --git a/mtls/configs/prod.yaml b/mtls/configs/prod.yaml new file mode 100644 index 0000000..bf28bb5 --- /dev/null +++ b/mtls/configs/prod.yaml @@ -0,0 +1,18 @@ +certificates: + providers: + - engine: file + basedir: /etc/ssl/mtls/${principal} + - engine: file + basedir: /etc/ssl/mtls/${principal} + leaf: fullchain.pem + chain: fullchain.pem + key: privkey.pem + ca: ../rootca.pem + - engine: file + basedir: /run/secrets/spiffe.io + leaf: tls.crt + chain: tls.crt + key: tls.key + ca: ca.pem + - engine: mint + ignore_failures: true diff --git a/mtls/configs/user.yaml b/mtls/configs/user.yaml new file mode 100644 index 0000000..ab86bc2 --- /dev/null +++ b/mtls/configs/user.yaml @@ -0,0 +1,13 @@ +certificates: + providers: + - engine: file + basedir: ${dirname:${nullor:${env:STEP_PERSONAL_CERTIFICATE}:-${nullor:${env:HOME}:-/home/${principal}}/.step/authorities/test/user.crt}} + leaf: ${basename:${nullor:${env:STEP_PERSONAL_CERTIFICATE}:-${nullor:${env:HOME}:-/home/${nullor:${env:USER}:-${principal}}}/.step/authorities/test/user.crt}} + key: ${basename:${nullor:${env:STEP_PERSONAL_PRIVATE_KEY}:-${nullor:${env:HOME}:-/home/${nullor:${env:USER}:-${principal}}}/.step/authorities/test/user.key}} + ca: root_ca.crt + - engine: tpm2_pkcs11 + ignore_failures: true + - engine: macos_keychain + ignore_failures: true + - engine: mint + ignore_failures: true \ No newline at end of file diff --git a/mtls/identity.go b/mtls/identity.go index 94d4a71..030bde4 100644 --- a/mtls/identity.go +++ b/mtls/identity.go @@ -8,7 +8,9 @@ import ( "os/user" "strings" + "go.fuhry.dev/runtime/constants" "go.fuhry.dev/runtime/mtls/certutil" + "go.fuhry.dev/runtime/utils/hostname" "go.fuhry.dev/runtime/utils/log" ) @@ -30,13 +32,15 @@ func (c PrincipalClass) String() string { return "user" case SSLCertificatePrincipal: return "tls" + case AnonymousPrincipal: + return "anonymous" } panic("invalid PrincipalClass") } const ( - defaultDefaultIdentity = "host" + defaultDefaultIdentity = constants.OrgSlug ) var ( @@ -68,18 +72,18 @@ type stubIdentity struct { var _ Identity = &stubIdentity{} -type identityLoaderFunc func(cls PrincipalClass, name string) (CertificateProvider, error) +type IdentityLoaderFunc func(cls PrincipalClass, name string) (CertificateProvider, error) type identityDriver struct { name string - load identityLoaderFunc + load IdentityLoaderFunc } var identityDrivers []*identityDriver var ErrUnsupportedClass = errors.New("this driver does not support loading this identity class") -func RegisterIdentityDriver(name string, load identityLoaderFunc) { +func RegisterIdentityDriver(name string, load IdentityLoaderFunc) { driver := &identityDriver{ name: name, load: load, @@ -145,21 +149,26 @@ func (id *substantiatedIdentity) IsValid() bool { } func identityIsValid(id CertificatePrimitive) bool { + err := validateIdentity(id) + return err == nil +} + +func validateIdentity(id CertificatePrimitive) error { cert, err := id.LeafCertificate() if err != nil { - return false + return fmt.Errorf("error reading leaf certificate: %v", err) } if err := certutil.ValidNow(cert); err != nil { - return false + return err } pkey, err := id.PrivateKey() if err != nil || pkey == nil { - return false + return fmt.Errorf("error reading private key: %v", err) } - return true + return nil } func (id *substantiatedIdentity) Equals(other Identity) bool { @@ -289,11 +298,23 @@ func NewSSLCertificate(certName string) Identity { } func init() { + if defaultMtlsIdentity == "" { + if hostname.IsLikelyUserMachine() { + if whoami, err := user.Current(); err == nil { + defaultMtlsIdentity = "user." + whoami.Username + } + } + if defaultMtlsIdentity == "" { + defaultMtlsIdentity = defaultDefaultIdentity + } + } flag.StringVar(&defaultMtlsIdentity, "mtls.id", defaultMtlsIdentity, "mTLS identity to use when not overridden by the application") // identityCache = make(map[string]*serviceIdentity, 0) logger = log.Default().WithPrefix("mtls") + + initConfig() } // SetDefaultIdentity sets the mtls id that is used when -mtls.id is not specified diff --git a/mtls/lazy_identity.go b/mtls/lazy_identity.go index 26557fa..1f66c7e 100644 --- a/mtls/lazy_identity.go +++ b/mtls/lazy_identity.go @@ -49,6 +49,10 @@ func (id *lazyIdentity) Equals(other Identity) bool { func (id *lazyIdentity) IsValid() bool { id.tryLoad() + if id.cls == AnonymousPrincipal { + return true + } + if id.cp != nil { return identityIsValid(id.cp) } @@ -64,35 +68,10 @@ func (id *lazyIdentity) tryLoad() { return } - if id.cls == AnonymousPrincipal { - id.cp = Anonymous() + if cp, err := DefaultConfig().NewIdentity(id.cls, id.name); err == nil { + id.cp = cp return } - if id.cls == SSLCertificatePrincipal { - if cert := NewSSLCertificate(id.name); cert.IsValid() { - id.cp = cert - } - return - } - for _, driver := range identityDrivers { - logger.V(1).Infof("trying driver %s to load %s identity %s", driver.name, id.cls, id.name) - if cert, err := driver.load(id.cls, id.name); err == nil { - subst := &substantiatedIdentity{cert} - if subst.Name() == id.name && subst.Class() == id.cls { - logger.V(2).Infof("driver %s reports it loaded identity %s:%s", driver.name, id.cls.String(), id.name) - id.cp = cert - return - } else { - logger.V(2).Warnf( - "driver %s successfully loaded a certificate, but it doesn't match what "+ - "we expected: class %s (expected)/%s (got); name %s (expected)/%s (got)", - driver.name, id.cls.String(), subst.Class().String(), - id.name, subst.Name()) - } - } else { - logger.V(2).Warnf("driver %s failed to load %s identity %s: %+v", driver.name, id.cls.String(), id.name, err) - } - } } func (id *lazyIdentity) LeafCertificate() (*x509.Certificate, error) { diff --git a/mtls/provider_file.go b/mtls/provider_file.go index 9312370..d5c7948 100644 --- a/mtls/provider_file.go +++ b/mtls/provider_file.go @@ -5,15 +5,19 @@ import ( "crypto" "crypto/tls" "crypto/x509" + "errors" "flag" "fmt" "os" "path" "sync" + "gopkg.in/yaml.v3" + "go.fuhry.dev/runtime/mtls/certutil" "go.fuhry.dev/runtime/mtls/fsnotify" "go.fuhry.dev/runtime/utils/fsutil" + "go.fuhry.dev/runtime/utils/subst" ) type FileBackedCertificate struct { @@ -42,6 +46,8 @@ type fileBackedRoots struct { initOnce sync.Once } +type fileProviderFactory struct{} + const ( defaultMtlsRootPath = "/etc/ssl/mtls" ) @@ -477,7 +483,7 @@ func (c *FileBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, err defer c.mu.Unlock() if c.tlsConfig != nil { - return c.tlsConfig, nil + return c.tlsConfig.Clone(), nil } fsnotify.NotifyPath(c.LeafPath, c.notifyEvent) @@ -490,7 +496,7 @@ func (c *FileBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, err GetClientCertificate: c.GetClientCertificate, } - return c.tlsConfig, nil + return c.tlsConfig.Clone(), nil } func (c *FileBackedCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -574,6 +580,118 @@ func RootPaths() []string { return rootPathsCopy } +func (f *fileProviderFactory) New(node *yaml.Node) (IdentityLoaderFunc, error) { + const ( + defaultLeaf = "fullchain.pem" + defaultKey = "privkey.pem" + defaultCA = "../rootca.pem" + ) + cfg := new(struct { + BaseDir string `yaml:"basedir"` + Leaf string `yaml:"leaf"` + Key string `yaml:"key"` + Chain string `yaml:"chain"` + CA string `yaml:"ca"` + }) + + if err := node.Decode(cfg); err != nil { + return nil, err + } + + if cfg.BaseDir == "" { + return nil, errors.New("basedir is required for file engine") + } + + if cfg.Leaf == "" { + cfg.Leaf = defaultLeaf + } + + if cfg.Key == "" { + cfg.Key = defaultKey + } + + if cfg.CA == "" { + cfg.CA = defaultCA + } + + return func(cls PrincipalClass, princ string) (CertificateProvider, error) { + kv := subst.KV{ + "class": cls.String(), + "principal": princ, + } + basedir, err := subst.Eval(kv, cfg.BaseDir) + if err != nil { + return nil, fmt.Errorf("failed to evaluate basedir expression: %q", err) + } + leaf, err := subst.Eval(kv, cfg.Leaf) + if err != nil { + return nil, fmt.Errorf("failed to evaluate leaf expression: %q", err) + } + key, err := subst.Eval(kv, cfg.Key) + if err != nil { + return nil, fmt.Errorf("failed to evaluate key expression: %q", err) + } + chain, err := subst.Eval(kv, cfg.Chain) + if err != nil { + return nil, fmt.Errorf("failed to evaluate chain expression: %q", err) + } + root, err := subst.Eval(kv, cfg.CA) + if err != nil { + return nil, fmt.Errorf("failed to evaluate ca expression: %q", err) + } + + logger.V(3).Debugf("fileProviderFactory: basedir expanded %q -> %q", cfg.BaseDir, basedir) + logger.V(3).Debugf("fileProviderFactory: leaf expanded %q -> %q", cfg.Leaf, leaf) + logger.V(3).Debugf("fileProviderFactory: key expanded %q -> %q", cfg.Key, key) + logger.V(3).Debugf("fileProviderFactory: chain expanded %q -> %q", cfg.Chain, chain) + logger.V(3).Debugf("fileProviderFactory: root expanded %q -> %q", cfg.CA, root) + + leaf = fsnotify.RealPath(path.Join(basedir, leaf)) + key = fsnotify.RealPath(path.Join(basedir, key)) + chain = fsnotify.RealPath(path.Join(basedir, chain)) + root = fsnotify.RealPath(path.Join(basedir, root)) + + if err := fsutil.FileExistsAndIsReadable(leaf); err != nil { + return nil, fmt.Errorf("while checking leaf certificate: %v", err) + } + if err := fsutil.FileExistsAndIsReadable(key); err != nil { + return nil, fmt.Errorf("while checking private key: %v", err) + } + if err := fsutil.FileExistsAndIsReadable(chain); err != nil { + chain = leaf + } + if err := fsutil.FileExistsAndIsReadable(root); err != nil { + return nil, fmt.Errorf("while checking root certificate: %v", err) + } + fbc := &FileBackedCertificate{ + LeafPath: leaf, + PrivateKeyPath: key, + IntermediatesPath: chain, + RootPath: root, + } + + if err := validateIdentity(fbc); err != nil { + logger.V(2).Noticef("loaded certificate failed validation: %v", err) + return nil, fmt.Errorf( + "loaded certificate is not valid: %v", err) + } + + id := &substantiatedIdentity{fbc} + if id.Class() != cls || id.Name() != princ { + return nil, fmt.Errorf( + "loaded identity did not match the one requested:\n requested: %s %s\n loaded: %s %s", + cls.String(), princ, id.Class().String(), id.Name()) + } + + logger.V(2).Infof( + "fileProviderFactory: loaded %s identity %q:\n"+ + " leaf: %s\n key: %s\n chain: %s\n root: %s", + cls.String(), princ, leaf, key, chain, root) + + return fbc, nil + }, nil +} + func init() { defaultRootCAFile = fmt.Sprintf("%s/rootca.pem", defaultMtlsRootPath) defaultIntermediateCAFile = fmt.Sprintf("%s/ca.pem", defaultMtlsRootPath) @@ -628,4 +746,6 @@ func init() { RegisterRootDriver("file_csi_spiffe_altname", func() (RootsPrimitive, error) { return csiSpiffeRootsAltName, nil }) + + RegisterProviderFactory("file", &fileProviderFactory{}) } diff --git a/mtls/provider_keychain_macos.go b/mtls/provider_keychain_macos.go index bf9c0ec..cbce997 100644 --- a/mtls/provider_keychain_macos.go +++ b/mtls/provider_keychain_macos.go @@ -24,6 +24,7 @@ import ( "go.fuhry.dev/runtime/utils/hashset" "go.fuhry.dev/runtime/utils/log" "go.fuhry.dev/runtime/utils/stringmatch" + "gopkg.in/yaml.v3" ) type macosKeychainCertificate struct { @@ -60,6 +61,8 @@ type kcSignerOpts struct { hash crypto.Hash } +type macosKeychainProviderFactory struct{} + func (kso *kcSignerOpts) HashFunc() crypto.Hash { return kso.hash } @@ -445,11 +448,17 @@ func (kcr *macosKeychainRoots) IntermediateCertificates() ([]*x509.Certificate, return getMtlsIntermediatesFromMacKeychain() } +func (f *macosKeychainProviderFactory) New(_ *yaml.Node) (identityLoaderFunc, error) { + return NewCertificateFromMacKeychain, nil +} + func init() { kcLogger = log.WithPrefix("mtls.macOSKeychain") RegisterIdentityDriver("macos_keychain", NewCertificateFromMacKeychain) + RegisterProviderFactory("macos_keychain", &macosKeychainProviderFactory{}) + RegisterRootDriver("macos_keychain", func() (RootsPrimitive, error) { return &macosKeychainRoots{}, nil }) diff --git a/mtls/provider_tpm2_pkcs11.go b/mtls/provider_tpm2_pkcs11.go index 48ad0fa..c755aef 100644 --- a/mtls/provider_tpm2_pkcs11.go +++ b/mtls/provider_tpm2_pkcs11.go @@ -12,10 +12,10 @@ import ( "go.fuhry.dev/runtime/constants" "go.fuhry.dev/runtime/mtls/certutil" - "go.fuhry.dev/runtime/utils/log" + "gopkg.in/yaml.v3" ) -var tpmLogger = log.WithPrefix("mtls.provider_tpm2_pkcs11") +type tpmProviderFactory struct{} type TPMBackedCertificate struct { CertificatePrimitive @@ -112,6 +112,25 @@ func (c *TPMBackedCertificate) NewDialContextFunc() DialContextFunc { return MakeDialContextFunc(c) } +func (f *tpmProviderFactory) New(_ *yaml.Node) (IdentityLoaderFunc, error) { + return func(cls PrincipalClass, serviceName string) (CertificateProvider, error) { + if cls != ServicePrincipal || serviceName != constants.DeviceTrustPrincipal { + return nil, ErrUnsupportedClass + } + + cert, err := NewTPMBackedCertificate() + if err != nil { + return nil, err + } + + if _, err := cert.LeafCertificate(); err != nil { + return nil, err + } + + return cert, nil + }, nil +} + func init() { RegisterIdentityDriver("tpm2-pkcs11", func(cls PrincipalClass, serviceName string) (CertificateProvider, error) { if cls != ServicePrincipal || serviceName != constants.DeviceTrustPrincipal { diff --git a/mtls/verify_names.go b/mtls/verify_names.go index 4640f62..c926d39 100644 --- a/mtls/verify_names.go +++ b/mtls/verify_names.go @@ -12,16 +12,16 @@ import ( "go.fuhry.dev/runtime/utils/log" ) -type IdentityClass uint +type VerifyMatchBy uint type RemoteIdentity struct { - Class IdentityClass + Class VerifyMatchBy Domain string Principal string } const ( - Domain IdentityClass = iota + Domain VerifyMatchBy = iota Service User All @@ -33,7 +33,7 @@ const ( exprMtlsInternalIdenitty = `^(?P[A-Za-z0-9_-]+)\.(?P[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*)\.mtls\.internal$` ) -func (c IdentityClass) String() string { +func (c VerifyMatchBy) String() string { switch c { case Domain: return "domain" @@ -48,9 +48,7 @@ func (c IdentityClass) String() string { panic("invalid value for IdentityClass") } -var ( - spiffeServiceIdentity, spiffeUserIdentity, mtlsInternalIdentity *regexp.Regexp -) +var spiffeServiceIdentity, spiffeUserIdentity, mtlsInternalIdentity *regexp.Regexp func init() { spiffeServiceIdentity = regexp.MustCompile(exprSpiffeServiceIdentity) @@ -79,7 +77,7 @@ type MTLSPeerVerifier interface { // // If IdentityClass is Service or User, the variadic arguments are the service names or // usernames that are permitted to connect. - AllowFrom(IdentityClass, ...string) + AllowFrom(VerifyMatchBy, ...string) // VerifyPeerCert conforms to the function prototype for `tls.Config.VerifyConnection`. // @@ -90,13 +88,13 @@ type MTLSPeerVerifier interface { type mtlsPeerVerifier struct { MTLSPeerVerifier - allowedPrincipals map[IdentityClass]*hashset.HashSet[string] + allowedPrincipals map[VerifyMatchBy]*hashset.HashSet[string] log log.Logger } func NewPeerNameVerifier() MTLSPeerVerifier { cv := &mtlsPeerVerifier{ - allowedPrincipals: make(map[IdentityClass]*hashset.HashSet[string]), + allowedPrincipals: make(map[VerifyMatchBy]*hashset.HashSet[string]), log: log.WithPrefix("MTLSPeerVerifier"), } @@ -184,7 +182,7 @@ func (cv *mtlsPeerVerifier) VerifyPeerCert(peerCert *x509.Certificate, verifyOpt return fmt.Errorf("none of the names in this certificate are allowed") } -func (cv *mtlsPeerVerifier) AllowFrom(class IdentityClass, principals ...string) { +func (cv *mtlsPeerVerifier) AllowFrom(class VerifyMatchBy, principals ...string) { if len(principals) < 1 && class != All { return } @@ -245,7 +243,7 @@ func (cv *mtlsPeerVerifier) checkName(name string) error { func ParseRemoteIdentity(name string) (*RemoteIdentity, error) { exps := []struct { expr *regexp.Regexp - class IdentityClass + class VerifyMatchBy }{ { expr: spiffeServiceIdentity, @@ -293,6 +291,10 @@ func (i *RemoteIdentity) ToDnsName() string { i.Principal, i.Class.String(), i.Domain) + case Service: + return fmt.Sprintf("%s.%s.mtls.internal", + i.Principal, + i.Domain) default: return fmt.Sprintf("%s.%s.%s.mtls.internal", i.Principal, diff --git a/mtls/verify_roots.go b/mtls/verify_roots.go index b52159f..d996958 100644 --- a/mtls/verify_roots.go +++ b/mtls/verify_roots.go @@ -168,17 +168,8 @@ func verifyMTLSCertificateChain(leafCert *x509.Certificate, intermediates []*x50 return nil } -func NewVerifyMTLSPeerCertificateFunc() (tlsVerifyPeerCertificatesFunc, error) { - vo, err := NewMTLSVerifyOpts() - if err != nil { - return nil, err - } - - return NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tls.RequireAndVerifyClientCert), nil -} - func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth tls.ClientAuthType) tlsVerifyPeerCertificatesFunc { - return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + verify := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (error, *x509.Certificate) { var leafCert *x509.Certificate intermediates := make([]*x509.Certificate, 0) @@ -190,22 +181,34 @@ func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth intermediates = append(intermediates, cert) } else { if leafCert != nil { - return ErrMultipleCertificatesPresented + return ErrMultipleCertificatesPresented, cert } leafCert = cert } } else { - return ErrCertificateParseFailed + return ErrCertificateParseFailed, nil } } if leafCert == nil { if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven { - return nil + return nil, nil } - return ErrNoCertificatePresented + return ErrNoCertificatePresented, nil } - return verifyMTLSCertificateChain(leafCert, intermediates, vo) + return verifyMTLSCertificateChain(leafCert, intermediates, vo), leafCert + } + + return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + err, leaf := verify(rawCerts, verifiedChains) + if err != nil { + if leaf != nil { + logger.Noticef("rejected peer cert %v: %v", leaf.Subject, err) + } else { + logger.Noticef("rejected peer connection (no cert presented): %v; clientAuth is %v", err, clientAuth.String()) + } + } + return err } } -- 2.52.0