]> go.fuhry.dev Git - runtime.git/commitdiff
[mtls] support configs for different environments, embed default configs
authorDan Fuhry <dan@fuhry.com>
Sun, 15 Mar 2026 00:04:24 +0000 (20:04 -0400)
committerDan Fuhry <dan@fuhry.com>
Sun, 15 Mar 2026 01:17:54 +0000 (21:17 -0400)
- Only try the identity drivers and paths listed in the config
- Use different default config based on environment heuristics
- Remove all hardcoded file paths, make file provider paths fully configurable

Supporting changes:
- Make IdentityLoaderFunc public
- Rename type: `IdentityClass` -> `VerifyMatchBy`

12 files changed:
cmd/echo_client/main.go
mtls/BUILD.bazel
mtls/config.go [new file with mode: 0644]
mtls/configs/prod.yaml [new file with mode: 0644]
mtls/configs/user.yaml [new file with mode: 0644]
mtls/identity.go
mtls/lazy_identity.go
mtls/provider_file.go
mtls/provider_keychain_macos.go
mtls/provider_tpm2_pkcs11.go
mtls/verify_names.go
mtls/verify_roots.go

index 8b1f98cbbaa2c8eb92e8d7dda9989903889e2992..37496a93b5601b49ca23cbb5a202fad4c7021483 100644 (file)
@@ -18,7 +18,11 @@ func main() {
        flag.Parse()
        logger := log.Default().WithPrefix("EchoClient")
 
-       clientId := mtls.DefaultIdentity()
+       mtlsConfig := mtls.DefaultConfig()
+       clientId, err := mtlsConfig.DefaultIdentity()
+       if err != nil {
+               logger.Panic(err)
+       }
        client, err := grpc.NewGrpcClient(ctx, "echo", clientId)
        if err != nil {
                logger.Panic(err)
index 08351c5593d34aba05197bb4ff24734846379630..9fa9e324dfee83c5a827121936a5b614341a40d9 100644 (file)
@@ -3,6 +3,7 @@ load("@rules_go//go:def.bzl", "go_library")
 go_library(
     name = "mtls",
     srcs = [
+        "config.go",
         "identity.go",
         "lazy_identity.go",
         "pkcs11.go",
@@ -19,6 +20,10 @@ go_library(
         "@rules_go//go/platform:linux_amd64": True,
         "//conditions:default": False,
     }),
+    embedsrcs = [
+        "configs/prod.yaml",
+        "configs/user.yaml",
+    ],
     importpath = "go.fuhry.dev/runtime/mtls",
     visibility = ["//visibility:public"],
     deps = [
@@ -27,8 +32,11 @@ go_library(
         "//mtls/fsnotify",
         "//utils/fsutil",
         "//utils/hashset",
+        "//utils/hostname",
         "//utils/log",
+        "//utils/subst",
         "@com_github_thalesignite_crypto11//:crypto11",
+        "@in_gopkg_yaml_v3//:yaml_v3",
     ] + select({
         "@rules_go//go/platform:darwin": [
             "//utils/stringmatch",
diff --git a/mtls/config.go b/mtls/config.go
new file mode 100644 (file)
index 0000000..bf42052
--- /dev/null
@@ -0,0 +1,214 @@
+package mtls
+
+import (
+       "embed"
+       "flag"
+       "fmt"
+       "os"
+       "strings"
+       "sync"
+
+       "gopkg.in/yaml.v3"
+
+       "go.fuhry.dev/runtime/constants"
+       "go.fuhry.dev/runtime/utils/fsutil"
+       "go.fuhry.dev/runtime/utils/hostname"
+)
+
+type Config interface {
+       NewIdentity(cls PrincipalClass, princ string) (Identity, error)
+       DefaultIdentity() (Identity, error)
+}
+
+type providerFactory interface {
+       New(node *yaml.Node) (IdentityLoaderFunc, error)
+}
+
+type providerNode struct {
+       Loader IdentityLoaderFunc
+}
+
+type configCertificates struct {
+       Providers []*providerNode `yaml:"providers"`
+}
+
+type cacheKey struct {
+       cls   PrincipalClass
+       princ string
+}
+
+type config struct {
+       Certificates *configCertificates `yaml:"certificates"`
+
+       cache   map[cacheKey]Identity
+       cacheMu sync.Mutex
+}
+
+const (
+       embedPrefix = "embed:"
+)
+
+var providerFactories = make(map[string]providerFactory)
+
+//go:embed configs
+var defaultConfigs embed.FS
+
+var (
+       defaultConfig               *config
+       defaultConfigOnce           sync.Once
+       defaultConfigPath           = ""
+       defaultEmbeddedConfigPath   = "embed:configs/prod.yaml"
+       defaultFilesystemConfigPath = "/etc/" + constants.OrgSlug + "/mtls.yaml"
+)
+
+func (c *config) NewIdentity(cls PrincipalClass, princ string) (Identity, error) {
+       if i, ok := c.cacheGet(cls, princ); ok {
+               return i, nil
+       }
+       if cls == AnonymousPrincipal {
+               id := &anonymousIdentity{}
+               c.cacheStore(cls, princ, id)
+               return id, nil
+       }
+       for _, p := range c.Certificates.Providers {
+               if cp, err := p.Loader(cls, princ); err == nil {
+                       id := &substantiatedIdentity{cp}
+
+                       c.cacheStore(cls, princ, id)
+                       return id, nil
+               }
+       }
+
+       return nil, fmt.Errorf(
+               "no provider was able to satisfy the request for identity with class %s, principal %s",
+               cls.String(), princ)
+}
+
+func (c *config) cacheGet(cls PrincipalClass, princ string) (Identity, bool) {
+       c.cacheMu.Lock()
+       defer c.cacheMu.Unlock()
+
+       if c.cache == nil {
+               c.cache = make(map[cacheKey]Identity)
+       }
+       k := cacheKey{cls, princ}
+
+       if entry, ok := c.cache[k]; ok {
+               if identityIsValid(entry) {
+                       return entry, true
+               }
+       }
+
+       return nil, false
+}
+
+func (c *config) cacheStore(cls PrincipalClass, princ string, id Identity) {
+       c.cacheMu.Lock()
+       defer c.cacheMu.Unlock()
+
+       if c.cache == nil {
+               c.cache = make(map[cacheKey]Identity)
+       }
+       k := cacheKey{cls, princ}
+
+       c.cache[k] = id
+}
+
+func (c *config) DefaultIdentity() (Identity, error) {
+       id := ParseIdentity(defaultMtlsIdentity)
+       return c.NewIdentity(id.Class(), id.Name())
+}
+
+func RegisterProviderFactory(name string, fac providerFactory) {
+       if _, ok := providerFactories[name]; ok {
+               panic(fmt.Sprintf("provider factory %q already registered", name))
+       }
+
+       providerFactories[name] = fac
+}
+
+func (p *providerNode) UnmarshalYAML(node *yaml.Node) error {
+       cfg := new(struct {
+               Engine         string `yaml:"engine"`
+               IgnoreFailures bool   `yaml:"ignore_failures"`
+       })
+
+       if err := node.Decode(cfg); err != nil {
+               return err
+       }
+
+       fac, ok := providerFactories[cfg.Engine]
+       if !ok {
+               if cfg.IgnoreFailures {
+                       p.Loader = nullIdentityLoader
+                       return nil
+               }
+               return fmt.Errorf("unknown provider engine: %q", cfg.Engine)
+       }
+
+       loader, err := fac.New(node)
+       if err != nil {
+               if cfg.IgnoreFailures {
+                       p.Loader = nullIdentityLoader
+                       return nil
+               }
+               return err
+       }
+
+       p.Loader = loader
+       return nil
+}
+
+func nullIdentityLoader(cls PrincipalClass, princ string) (CertificateProvider, error) {
+       return nil, ErrUnsupportedClass
+}
+
+func initConfig() {
+       defaultConfigPath = defaultEmbeddedConfigPath
+       if hostname.IsLikelyUserMachine() {
+               defaultConfigPath = "embed:configs/user.yaml"
+       }
+       if err := fsutil.FileExistsAndIsReadable(defaultFilesystemConfigPath); err == nil {
+               defaultConfigPath = defaultFilesystemConfigPath
+       }
+
+       flag.StringVar(&defaultConfigPath, "mtls.config", defaultConfigPath, "path to the default mtls configuration file")
+}
+
+func DefaultConfig() Config {
+       if !flag.Parsed() {
+               panic("mtls: DefaultConfig() called before flags were parsed")
+       }
+
+       defaultConfigOnce.Do(func() {
+               var contents []byte
+
+               logger.Noticef("using defaultConfigPath = %q", defaultConfigPath)
+               if defaultConfigPath == "" {
+               }
+               if strings.HasPrefix(defaultConfigPath, embedPrefix) {
+                       configPath := defaultConfigPath[len(embedPrefix):]
+                       if c, err := defaultConfigs.ReadFile(configPath); err != nil {
+                               logger.Panicf("failed to read embedded config %q: %v", configPath, err)
+                       } else {
+                               contents = c
+                       }
+               } else {
+                       if c, err := os.ReadFile(defaultConfigPath); err != nil {
+                               logger.Panicf("failed to read mtls config %q from filesystem: %v",
+                                       defaultConfigPath, err)
+                       } else {
+                               contents = c
+                       }
+               }
+
+               defaultConfig = &config{}
+               if err := yaml.Unmarshal(contents, defaultConfig); err != nil {
+                       logger.Panicf("error loading mtls config %q: %v", defaultConfigPath, err)
+               }
+
+               logger.Noticef("loaded %d providers from config %q", len(defaultConfig.Certificates.Providers), defaultConfigPath)
+       })
+
+       return defaultConfig
+}
diff --git a/mtls/configs/prod.yaml b/mtls/configs/prod.yaml
new file mode 100644 (file)
index 0000000..bf28bb5
--- /dev/null
@@ -0,0 +1,18 @@
+certificates:
+  providers:
+    - engine: file
+      basedir: /etc/ssl/mtls/${principal}
+    - engine: file
+      basedir: /etc/ssl/mtls/${principal}
+      leaf: fullchain.pem
+      chain: fullchain.pem
+      key: privkey.pem
+      ca: ../rootca.pem
+    - engine: file
+      basedir: /run/secrets/spiffe.io
+      leaf: tls.crt
+      chain: tls.crt
+      key: tls.key
+      ca: ca.pem
+    - engine: mint
+      ignore_failures: true
diff --git a/mtls/configs/user.yaml b/mtls/configs/user.yaml
new file mode 100644 (file)
index 0000000..ab86bc2
--- /dev/null
@@ -0,0 +1,13 @@
+certificates:
+  providers:
+    - engine: file
+      basedir: ${dirname:${nullor:${env:STEP_PERSONAL_CERTIFICATE}:-${nullor:${env:HOME}:-/home/${principal}}/.step/authorities/test/user.crt}}
+      leaf: ${basename:${nullor:${env:STEP_PERSONAL_CERTIFICATE}:-${nullor:${env:HOME}:-/home/${nullor:${env:USER}:-${principal}}}/.step/authorities/test/user.crt}}
+      key: ${basename:${nullor:${env:STEP_PERSONAL_PRIVATE_KEY}:-${nullor:${env:HOME}:-/home/${nullor:${env:USER}:-${principal}}}/.step/authorities/test/user.key}}
+      ca: root_ca.crt
+    - engine: tpm2_pkcs11
+      ignore_failures: true
+    - engine: macos_keychain
+      ignore_failures: true
+    - engine: mint
+      ignore_failures: true
\ No newline at end of file
index 94d4a71ed77b3667dc148320c721534c5d26597d..030bde425c73b9e7e4496e662c141d0a064155d1 100644 (file)
@@ -8,7 +8,9 @@ import (
        "os/user"
        "strings"
 
+       "go.fuhry.dev/runtime/constants"
        "go.fuhry.dev/runtime/mtls/certutil"
+       "go.fuhry.dev/runtime/utils/hostname"
        "go.fuhry.dev/runtime/utils/log"
 )
 
@@ -30,13 +32,15 @@ func (c PrincipalClass) String() string {
                return "user"
        case SSLCertificatePrincipal:
                return "tls"
+       case AnonymousPrincipal:
+               return "anonymous"
        }
 
        panic("invalid PrincipalClass")
 }
 
 const (
-       defaultDefaultIdentity = "host"
+       defaultDefaultIdentity = constants.OrgSlug
 )
 
 var (
@@ -68,18 +72,18 @@ type stubIdentity struct {
 
 var _ Identity = &stubIdentity{}
 
-type identityLoaderFunc func(cls PrincipalClass, name string) (CertificateProvider, error)
+type IdentityLoaderFunc func(cls PrincipalClass, name string) (CertificateProvider, error)
 
 type identityDriver struct {
        name string
-       load identityLoaderFunc
+       load IdentityLoaderFunc
 }
 
 var identityDrivers []*identityDriver
 
 var ErrUnsupportedClass = errors.New("this driver does not support loading this identity class")
 
-func RegisterIdentityDriver(name string, load identityLoaderFunc) {
+func RegisterIdentityDriver(name string, load IdentityLoaderFunc) {
        driver := &identityDriver{
                name: name,
                load: load,
@@ -145,21 +149,26 @@ func (id *substantiatedIdentity) IsValid() bool {
 }
 
 func identityIsValid(id CertificatePrimitive) bool {
+       err := validateIdentity(id)
+       return err == nil
+}
+
+func validateIdentity(id CertificatePrimitive) error {
        cert, err := id.LeafCertificate()
        if err != nil {
-               return false
+               return fmt.Errorf("error reading leaf certificate: %v", err)
        }
 
        if err := certutil.ValidNow(cert); err != nil {
-               return false
+               return err
        }
 
        pkey, err := id.PrivateKey()
        if err != nil || pkey == nil {
-               return false
+               return fmt.Errorf("error reading private key: %v", err)
        }
 
-       return true
+       return nil
 }
 
 func (id *substantiatedIdentity) Equals(other Identity) bool {
@@ -289,11 +298,23 @@ func NewSSLCertificate(certName string) Identity {
 }
 
 func init() {
+       if defaultMtlsIdentity == "" {
+               if hostname.IsLikelyUserMachine() {
+                       if whoami, err := user.Current(); err == nil {
+                               defaultMtlsIdentity = "user." + whoami.Username
+                       }
+               }
+               if defaultMtlsIdentity == "" {
+                       defaultMtlsIdentity = defaultDefaultIdentity
+               }
+       }
        flag.StringVar(&defaultMtlsIdentity, "mtls.id", defaultMtlsIdentity, "mTLS identity to use when not overridden by the application")
 
        // identityCache = make(map[string]*serviceIdentity, 0)
 
        logger = log.Default().WithPrefix("mtls")
+
+       initConfig()
 }
 
 // SetDefaultIdentity sets the mtls id that is used when -mtls.id is not specified
index 26557fa5055eaad89e7ac235284daa7fcec59074..1f66c7ec27a970121f078c8c42c233b549b90b6d 100644 (file)
@@ -49,6 +49,10 @@ func (id *lazyIdentity) Equals(other Identity) bool {
 func (id *lazyIdentity) IsValid() bool {
        id.tryLoad()
 
+       if id.cls == AnonymousPrincipal {
+               return true
+       }
+
        if id.cp != nil {
                return identityIsValid(id.cp)
        }
@@ -64,35 +68,10 @@ func (id *lazyIdentity) tryLoad() {
                return
        }
 
-       if id.cls == AnonymousPrincipal {
-               id.cp = Anonymous()
+       if cp, err := DefaultConfig().NewIdentity(id.cls, id.name); err == nil {
+               id.cp = cp
                return
        }
-       if id.cls == SSLCertificatePrincipal {
-               if cert := NewSSLCertificate(id.name); cert.IsValid() {
-                       id.cp = cert
-               }
-               return
-       }
-       for _, driver := range identityDrivers {
-               logger.V(1).Infof("trying driver %s to load %s identity %s", driver.name, id.cls, id.name)
-               if cert, err := driver.load(id.cls, id.name); err == nil {
-                       subst := &substantiatedIdentity{cert}
-                       if subst.Name() == id.name && subst.Class() == id.cls {
-                               logger.V(2).Infof("driver %s reports it loaded identity %s:%s", driver.name, id.cls.String(), id.name)
-                               id.cp = cert
-                               return
-                       } else {
-                               logger.V(2).Warnf(
-                                       "driver %s successfully loaded a certificate, but it doesn't match what "+
-                                               "we expected: class %s (expected)/%s (got); name %s (expected)/%s (got)",
-                                       driver.name, id.cls.String(), subst.Class().String(),
-                                       id.name, subst.Name())
-                       }
-               } else {
-                       logger.V(2).Warnf("driver %s failed to load %s identity %s: %+v", driver.name, id.cls.String(), id.name, err)
-               }
-       }
 }
 
 func (id *lazyIdentity) LeafCertificate() (*x509.Certificate, error) {
index 93123700b02a0cd7d3b4a3c69df931190a0e3876..d5c7948772fa241e8e4f0e375b5131c09ed8c539 100644 (file)
@@ -5,15 +5,19 @@ import (
        "crypto"
        "crypto/tls"
        "crypto/x509"
+       "errors"
        "flag"
        "fmt"
        "os"
        "path"
        "sync"
 
+       "gopkg.in/yaml.v3"
+
        "go.fuhry.dev/runtime/mtls/certutil"
        "go.fuhry.dev/runtime/mtls/fsnotify"
        "go.fuhry.dev/runtime/utils/fsutil"
+       "go.fuhry.dev/runtime/utils/subst"
 )
 
 type FileBackedCertificate struct {
@@ -42,6 +46,8 @@ type fileBackedRoots struct {
        initOnce sync.Once
 }
 
+type fileProviderFactory struct{}
+
 const (
        defaultMtlsRootPath = "/etc/ssl/mtls"
 )
@@ -477,7 +483,7 @@ func (c *FileBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, err
        defer c.mu.Unlock()
 
        if c.tlsConfig != nil {
-               return c.tlsConfig, nil
+               return c.tlsConfig.Clone(), nil
        }
 
        fsnotify.NotifyPath(c.LeafPath, c.notifyEvent)
@@ -490,7 +496,7 @@ func (c *FileBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, err
                GetClientCertificate: c.GetClientCertificate,
        }
 
-       return c.tlsConfig, nil
+       return c.tlsConfig.Clone(), nil
 }
 
 func (c *FileBackedCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -574,6 +580,118 @@ func RootPaths() []string {
        return rootPathsCopy
 }
 
+func (f *fileProviderFactory) New(node *yaml.Node) (IdentityLoaderFunc, error) {
+       const (
+               defaultLeaf = "fullchain.pem"
+               defaultKey  = "privkey.pem"
+               defaultCA   = "../rootca.pem"
+       )
+       cfg := new(struct {
+               BaseDir string `yaml:"basedir"`
+               Leaf    string `yaml:"leaf"`
+               Key     string `yaml:"key"`
+               Chain   string `yaml:"chain"`
+               CA      string `yaml:"ca"`
+       })
+
+       if err := node.Decode(cfg); err != nil {
+               return nil, err
+       }
+
+       if cfg.BaseDir == "" {
+               return nil, errors.New("basedir is required for file engine")
+       }
+
+       if cfg.Leaf == "" {
+               cfg.Leaf = defaultLeaf
+       }
+
+       if cfg.Key == "" {
+               cfg.Key = defaultKey
+       }
+
+       if cfg.CA == "" {
+               cfg.CA = defaultCA
+       }
+
+       return func(cls PrincipalClass, princ string) (CertificateProvider, error) {
+               kv := subst.KV{
+                       "class":     cls.String(),
+                       "principal": princ,
+               }
+               basedir, err := subst.Eval(kv, cfg.BaseDir)
+               if err != nil {
+                       return nil, fmt.Errorf("failed to evaluate basedir expression: %q", err)
+               }
+               leaf, err := subst.Eval(kv, cfg.Leaf)
+               if err != nil {
+                       return nil, fmt.Errorf("failed to evaluate leaf expression: %q", err)
+               }
+               key, err := subst.Eval(kv, cfg.Key)
+               if err != nil {
+                       return nil, fmt.Errorf("failed to evaluate key expression: %q", err)
+               }
+               chain, err := subst.Eval(kv, cfg.Chain)
+               if err != nil {
+                       return nil, fmt.Errorf("failed to evaluate chain expression: %q", err)
+               }
+               root, err := subst.Eval(kv, cfg.CA)
+               if err != nil {
+                       return nil, fmt.Errorf("failed to evaluate ca expression: %q", err)
+               }
+
+               logger.V(3).Debugf("fileProviderFactory: basedir expanded %q -> %q", cfg.BaseDir, basedir)
+               logger.V(3).Debugf("fileProviderFactory: leaf expanded %q -> %q", cfg.Leaf, leaf)
+               logger.V(3).Debugf("fileProviderFactory: key expanded %q -> %q", cfg.Key, key)
+               logger.V(3).Debugf("fileProviderFactory: chain expanded %q -> %q", cfg.Chain, chain)
+               logger.V(3).Debugf("fileProviderFactory: root expanded %q -> %q", cfg.CA, root)
+
+               leaf = fsnotify.RealPath(path.Join(basedir, leaf))
+               key = fsnotify.RealPath(path.Join(basedir, key))
+               chain = fsnotify.RealPath(path.Join(basedir, chain))
+               root = fsnotify.RealPath(path.Join(basedir, root))
+
+               if err := fsutil.FileExistsAndIsReadable(leaf); err != nil {
+                       return nil, fmt.Errorf("while checking leaf certificate: %v", err)
+               }
+               if err := fsutil.FileExistsAndIsReadable(key); err != nil {
+                       return nil, fmt.Errorf("while checking private key: %v", err)
+               }
+               if err := fsutil.FileExistsAndIsReadable(chain); err != nil {
+                       chain = leaf
+               }
+               if err := fsutil.FileExistsAndIsReadable(root); err != nil {
+                       return nil, fmt.Errorf("while checking root certificate: %v", err)
+               }
+               fbc := &FileBackedCertificate{
+                       LeafPath:          leaf,
+                       PrivateKeyPath:    key,
+                       IntermediatesPath: chain,
+                       RootPath:          root,
+               }
+
+               if err := validateIdentity(fbc); err != nil {
+                       logger.V(2).Noticef("loaded certificate failed validation: %v", err)
+                       return nil, fmt.Errorf(
+                               "loaded certificate is not valid: %v", err)
+               }
+
+               id := &substantiatedIdentity{fbc}
+               if id.Class() != cls || id.Name() != princ {
+                       return nil, fmt.Errorf(
+                               "loaded identity did not match the one requested:\n  requested: %s %s\n  loaded: %s %s",
+                               cls.String(), princ, id.Class().String(), id.Name())
+               }
+
+               logger.V(2).Infof(
+                       "fileProviderFactory: loaded %s identity %q:\n"+
+                               "  leaf:  %s\n  key:   %s\n  chain: %s\n  root:  %s",
+                       cls.String(), princ, leaf, key, chain, root)
+
+               return fbc, nil
+       }, nil
+}
+
 func init() {
        defaultRootCAFile = fmt.Sprintf("%s/rootca.pem", defaultMtlsRootPath)
        defaultIntermediateCAFile = fmt.Sprintf("%s/ca.pem", defaultMtlsRootPath)
@@ -628,4 +746,6 @@ func init() {
        RegisterRootDriver("file_csi_spiffe_altname", func() (RootsPrimitive, error) {
                return csiSpiffeRootsAltName, nil
        })
+
+       RegisterProviderFactory("file", &fileProviderFactory{})
 }
index bf9c0ec9a5c5b81a7ea97ccb6c73e6124839de9a..cbce997165a54dce1e0ed3bdb1500eaf054c231c 100644 (file)
@@ -24,6 +24,7 @@ import (
        "go.fuhry.dev/runtime/utils/hashset"
        "go.fuhry.dev/runtime/utils/log"
        "go.fuhry.dev/runtime/utils/stringmatch"
+       "gopkg.in/yaml.v3"
 )
 
 type macosKeychainCertificate struct {
@@ -60,6 +61,8 @@ type kcSignerOpts struct {
        hash crypto.Hash
 }
 
+type macosKeychainProviderFactory struct{}
+
 func (kso *kcSignerOpts) HashFunc() crypto.Hash {
        return kso.hash
 }
@@ -445,11 +448,17 @@ func (kcr *macosKeychainRoots) IntermediateCertificates() ([]*x509.Certificate,
        return getMtlsIntermediatesFromMacKeychain()
 }
 
+func (f *macosKeychainProviderFactory) New(_ *yaml.Node) (identityLoaderFunc, error) {
+       return NewCertificateFromMacKeychain, nil
+}
+
 func init() {
        kcLogger = log.WithPrefix("mtls.macOSKeychain")
 
        RegisterIdentityDriver("macos_keychain", NewCertificateFromMacKeychain)
 
+       RegisterProviderFactory("macos_keychain", &macosKeychainProviderFactory{})
+
        RegisterRootDriver("macos_keychain", func() (RootsPrimitive, error) {
                return &macosKeychainRoots{}, nil
        })
index 48ad0fad6bd62011be815679b26a7acb8aa364ae..c755aef01f4e86328d2a20ed089437fcfddcea8b 100644 (file)
@@ -12,10 +12,10 @@ import (
 
        "go.fuhry.dev/runtime/constants"
        "go.fuhry.dev/runtime/mtls/certutil"
-       "go.fuhry.dev/runtime/utils/log"
+       "gopkg.in/yaml.v3"
 )
 
-var tpmLogger = log.WithPrefix("mtls.provider_tpm2_pkcs11")
+type tpmProviderFactory struct{}
 
 type TPMBackedCertificate struct {
        CertificatePrimitive
@@ -112,6 +112,25 @@ func (c *TPMBackedCertificate) NewDialContextFunc() DialContextFunc {
        return MakeDialContextFunc(c)
 }
 
+func (f *tpmProviderFactory) New(_ *yaml.Node) (IdentityLoaderFunc, error) {
+       return func(cls PrincipalClass, serviceName string) (CertificateProvider, error) {
+               if cls != ServicePrincipal || serviceName != constants.DeviceTrustPrincipal {
+                       return nil, ErrUnsupportedClass
+               }
+
+               cert, err := NewTPMBackedCertificate()
+               if err != nil {
+                       return nil, err
+               }
+
+               if _, err := cert.LeafCertificate(); err != nil {
+                       return nil, err
+               }
+
+               return cert, nil
+       }, nil
+}
+
 func init() {
        RegisterIdentityDriver("tpm2-pkcs11", func(cls PrincipalClass, serviceName string) (CertificateProvider, error) {
                if cls != ServicePrincipal || serviceName != constants.DeviceTrustPrincipal {
index 4640f62f7b768879f2836dd5009db161d735d833..c926d3962681e159fa266c61e4c2109dc99ef935 100644 (file)
@@ -12,16 +12,16 @@ import (
        "go.fuhry.dev/runtime/utils/log"
 )
 
-type IdentityClass uint
+type VerifyMatchBy uint
 
 type RemoteIdentity struct {
-       Class     IdentityClass
+       Class     VerifyMatchBy
        Domain    string
        Principal string
 }
 
 const (
-       Domain IdentityClass = iota
+       Domain VerifyMatchBy = iota
        Service
        User
        All
@@ -33,7 +33,7 @@ const (
        exprMtlsInternalIdenitty  = `^(?P<principal>[A-Za-z0-9_-]+)\.(?P<domain>[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*)\.mtls\.internal$`
 )
 
-func (c IdentityClass) String() string {
+func (c VerifyMatchBy) String() string {
        switch c {
        case Domain:
                return "domain"
@@ -48,9 +48,7 @@ func (c IdentityClass) String() string {
        panic("invalid value for IdentityClass")
 }
 
-var (
-       spiffeServiceIdentity, spiffeUserIdentity, mtlsInternalIdentity *regexp.Regexp
-)
+var spiffeServiceIdentity, spiffeUserIdentity, mtlsInternalIdentity *regexp.Regexp
 
 func init() {
        spiffeServiceIdentity = regexp.MustCompile(exprSpiffeServiceIdentity)
@@ -79,7 +77,7 @@ type MTLSPeerVerifier interface {
        //
        // If IdentityClass is Service or User, the variadic arguments are the service names or
        // usernames that are permitted to connect.
-       AllowFrom(IdentityClass, ...string)
+       AllowFrom(VerifyMatchBy, ...string)
 
        // VerifyPeerCert conforms to the function prototype for `tls.Config.VerifyConnection`.
        //
@@ -90,13 +88,13 @@ type MTLSPeerVerifier interface {
 type mtlsPeerVerifier struct {
        MTLSPeerVerifier
 
-       allowedPrincipals map[IdentityClass]*hashset.HashSet[string]
+       allowedPrincipals map[VerifyMatchBy]*hashset.HashSet[string]
        log               log.Logger
 }
 
 func NewPeerNameVerifier() MTLSPeerVerifier {
        cv := &mtlsPeerVerifier{
-               allowedPrincipals: make(map[IdentityClass]*hashset.HashSet[string]),
+               allowedPrincipals: make(map[VerifyMatchBy]*hashset.HashSet[string]),
                log:               log.WithPrefix("MTLSPeerVerifier"),
        }
 
@@ -184,7 +182,7 @@ func (cv *mtlsPeerVerifier) VerifyPeerCert(peerCert *x509.Certificate, verifyOpt
        return fmt.Errorf("none of the names in this certificate are allowed")
 }
 
-func (cv *mtlsPeerVerifier) AllowFrom(class IdentityClass, principals ...string) {
+func (cv *mtlsPeerVerifier) AllowFrom(class VerifyMatchBy, principals ...string) {
        if len(principals) < 1 && class != All {
                return
        }
@@ -245,7 +243,7 @@ func (cv *mtlsPeerVerifier) checkName(name string) error {
 func ParseRemoteIdentity(name string) (*RemoteIdentity, error) {
        exps := []struct {
                expr  *regexp.Regexp
-               class IdentityClass
+               class VerifyMatchBy
        }{
                {
                        expr:  spiffeServiceIdentity,
@@ -293,6 +291,10 @@ func (i *RemoteIdentity) ToDnsName() string {
                        i.Principal,
                        i.Class.String(),
                        i.Domain)
+       case Service:
+               return fmt.Sprintf("%s.%s.mtls.internal",
+                       i.Principal,
+                       i.Domain)
        default:
                return fmt.Sprintf("%s.%s.%s.mtls.internal",
                        i.Principal,
index b52159f62812f2d2ce85fef8a110612395d49b8d..d996958ee0d4206943c12039d76e9df015909a4e 100644 (file)
@@ -168,17 +168,8 @@ func verifyMTLSCertificateChain(leafCert *x509.Certificate, intermediates []*x50
        return nil
 }
 
-func NewVerifyMTLSPeerCertificateFunc() (tlsVerifyPeerCertificatesFunc, error) {
-       vo, err := NewMTLSVerifyOpts()
-       if err != nil {
-               return nil, err
-       }
-
-       return NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tls.RequireAndVerifyClientCert), nil
-}
-
 func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth tls.ClientAuthType) tlsVerifyPeerCertificatesFunc {
-       return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+       verify := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (error, *x509.Certificate) {
                var leafCert *x509.Certificate
 
                intermediates := make([]*x509.Certificate, 0)
@@ -190,22 +181,34 @@ func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth
                                        intermediates = append(intermediates, cert)
                                } else {
                                        if leafCert != nil {
-                                               return ErrMultipleCertificatesPresented
+                                               return ErrMultipleCertificatesPresented, cert
                                        }
                                        leafCert = cert
                                }
                        } else {
-                               return ErrCertificateParseFailed
+                               return ErrCertificateParseFailed, nil
                        }
                }
 
                if leafCert == nil {
                        if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven {
-                               return nil
+                               return nil, nil
                        }
-                       return ErrNoCertificatePresented
+                       return ErrNoCertificatePresented, nil
                }
 
-               return verifyMTLSCertificateChain(leafCert, intermediates, vo)
+               return verifyMTLSCertificateChain(leafCert, intermediates, vo), leafCert
+       }
+
+       return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
+               err, leaf := verify(rawCerts, verifiedChains)
+               if err != nil {
+                       if leaf != nil {
+                               logger.Noticef("rejected peer cert %v: %v", leaf.Subject, err)
+                       } else {
+                               logger.Noticef("rejected peer connection (no cert presented): %v; clientAuth is %v", err, clientAuth.String())
+                       }
+               }
+               return err
        }
 }