]> go.fuhry.dev Git - runtime.git/commitdiff
[mtls] add lazy identities, loader improvements
authorDan Fuhry <dan@fuhry.com>
Wed, 19 Nov 2025 21:32:34 +0000 (16:32 -0500)
committerDan Fuhry <dan@fuhry.com>
Wed, 19 Nov 2025 21:32:34 +0000 (16:32 -0500)
- indicate principal class and name separately to loader drivers
- add LazyIdentity to delay calling drivers until credentials are needed

bazel/workspace-status
constants/BUILD.bazel
constants/constants_in.go
mtls/BUILD.bazel
mtls/certutil/certutil.go
mtls/identity.go
mtls/lazy_identity.go [new file with mode: 0644]
mtls/provider_file.go
mtls/provider_keychain_macos.go
mtls/provider_tpm2_pkcs11.go
mtls/verify_names.go

index b7641f6536bb04a5e4d594c2bf826034512cf27b..911a78ea9b76e0ac675388ee95709f5544fe1741 100755 (executable)
@@ -6,21 +6,22 @@ if test -r "$basedir/workspace-status.local"; then
        . "$basedir/workspace-status.local"
 fi
 
-ROOT_DOMAIN=${ROOT_DOMAIN:-"fuhry.dev"}
-DEFAULT_REGION=${DEFAULT_REGION:-"hq"}
-DEFAULT_HOST_DOMAIN=${DEFAULT_HOST_DOMAIN:-"${DEFAULT_REGION}.${ROOT_DOMAIN}"}
-SD_DOMAIN=${SD_DOMAIN:-"v.${ROOT_DOMAIN}"}
-WEB_SERVICES_DOMAIN=${WEB_SERVICES_DOMAIN:-"${ROOT_DOMAIN}"}
-MACHINES_HOST=${MACHINES_HOST:-"machines.${WEB_SERVICES_DOMAIN}"}
-MACHINES_MQTT_TOPIC=${MACHINES_MQTT_TOPIC:-"machines/events"}
-DBUS_PREFIX=${DBUS_PREFIX:-"dev.fuhry.runtime"}
-DBUS_PATH=${DBUS_PATH:-"/${DBUS_PREFIX//\./\/}"}
-ORG_NAME=${ORG_NAME:-"FooCorp"}
-ORG_SLUG=${ORG_SLUG:-"runtime"}
-SYSTEM_CONF_DIR=${SYSTEM_CONF_DIR:-"/etc/${ORG_SLUG}"}
-ROOT_CA_NAME=${ROOT_CA_NAME:-"${ORG_NAME} Root"}
-INT_CA_NAME=${INT_CA_NAME:-"${ORG_NAME} Intermediate mTLS"}
-DEVICE_TRUST_TOKEN_NAME=${DEVICE_TRUST_TOKEN_NAME:-"${ORG_NAME} Device Trust"}
+ROOT_DOMAIN="${ROOT_DOMAIN:-"fuhry.dev"}"
+DEFAULT_REGION="${DEFAULT_REGION:-"hq"}"
+DEFAULT_HOST_DOMAIN="${DEFAULT_HOST_DOMAIN:-"${DEFAULT_REGION}.${ROOT_DOMAIN}"}"
+SD_DOMAIN="${SD_DOMAIN:-"v.${ROOT_DOMAIN}"}"
+WEB_SERVICES_DOMAIN="${WEB_SERVICES_DOMAIN:-"${ROOT_DOMAIN}"}"
+MACHINES_HOST="${MACHINES_HOST:-"machines.${WEB_SERVICES_DOMAIN}"}"
+MACHINES_MQTT_TOPIC="${MACHINES_MQTT_TOPIC:-"machines/events"}"
+DBUS_PREFIX="${DBUS_PREFIX:-"dev.fuhry.runtime"}"
+DBUS_PATH="${DBUS_PATH:-"/${DBUS_PREFIX//\./\/}"}"
+ORG_NAME="${ORG_NAME:-"FooCorp"}"
+ORG_SLUG="${ORG_SLUG:-"runtime"}"
+SYSTEM_CONF_DIR="${SYSTEM_CONF_DIR:-"/etc/${ORG_SLUG}"}"
+ROOT_CA_NAME="${ROOT_CA_NAME:-"${ORG_NAME} Root"}"
+INT_CA_NAME="${INT_CA_NAME:-"${ORG_NAME} Intermediate mTLS"}"
+DEVICE_TRUST_TOKEN_NAME="${DEVICE_TRUST_TOKEN_NAME:-"${ORG_NAME} Device Trust"}"
+DEVICE_TRUST_PRINCIPAL="${DEVICE_TRUST_PRINCIPAL:-"devicetrust"}"
 TAG="$(cd "$basedir/.."; git describe HEAD 2>/dev/null)"
 VERSION="${TAG:-0.0.0+unset}"
 
@@ -41,6 +42,7 @@ vars=(
        ROOT_CA_NAME
        INT_CA_NAME
        DEVICE_TRUST_TOKEN_NAME
+       DEVICE_TRUST_PRINCIPAL
        VERSION
 )
 
index ea9af4b9c4d8087bd8ea1de9fe7d52d39b7d4d9c..a474c1f4025c06d5bfc2c04fa5c834263557f3de 100644 (file)
@@ -1,7 +1,7 @@
 load("@rules_go//go:def.bzl", "go_library")
 
 # Ignore this package in gazelle so constants_in.go is picked up by IDEs but not builds.
-# gazelle:ignore
+# gazelle:exclude constants_in.go
 
 genrule(
     name = "constants_go",
index dc549a66ec2a4f6ef42393c5fdfac94b3ac8a984..5bf193823d9c9fb517013fd64c126780d8fed601 100644 (file)
@@ -18,6 +18,7 @@ const (
        RootCAName           = "$ROOT_CA_NAME"
        IntCAName            = "$INT_CA_NAME"
        DeviceTrustTokenName = "$DEVICE_TRUST_TOKEN_NAME"
+       DeviceTrustPrincipal = "$DEVICE_TRUST_PRINCIPAL"
 
        Version = "$VERSION"
 )
index af97dbdfd2c230916fa03d6c64dea38c7779409f..08351c5593d34aba05197bb4ff24734846379630 100644 (file)
@@ -4,6 +4,7 @@ go_library(
     name = "mtls",
     srcs = [
         "identity.go",
+        "lazy_identity.go",
         "pkcs11.go",
         "provider_anonymous.go",
         "provider_file.go",
index 3e1bc997a63aa634a0679de013869cb8117e0f11..4f17681a9e95477c0c661e706e2a881f490a00a5 100644 (file)
@@ -9,6 +9,7 @@ import (
        "net/url"
        "os"
        "strings"
+       "time"
 )
 
 var oidSubjectAltName = asn1.ObjectIdentifier{2, 5, 29, 17}
@@ -153,3 +154,21 @@ func Fingerprint(cert *x509.Certificate, hash crypto.Hash) []byte {
        hasher.Write(cert.Raw)
        return hasher.Sum(dest)
 }
+
+func ValidNow(cert *x509.Certificate) error {
+       now := time.Now()
+
+       if now.Before(cert.NotBefore) {
+               return fmt.Errorf("certificate is not valid until %s (now = %s)",
+                       cert.NotBefore.UTC().Format(time.RFC3339),
+                       now.UTC().Format(time.RFC3339))
+       }
+
+       if now.After(cert.NotAfter) {
+               return fmt.Errorf("certificate expired at %s (now = %s)",
+                       cert.NotAfter.UTC().Format(time.RFC3339),
+                       now.UTC().Format(time.RFC3339))
+       }
+
+       return nil
+}
index 4631a2b75566d1e8ac870352acf01cdfe8d64ae5..94d4a71ed77b3667dc148320c721534c5d26597d 100644 (file)
@@ -2,11 +2,11 @@ package mtls
 
 import (
        "crypto/tls"
+       "errors"
        "flag"
        "fmt"
        "os/user"
        "strings"
-       "time"
 
        "go.fuhry.dev/runtime/mtls/certutil"
        "go.fuhry.dev/runtime/utils/log"
@@ -57,7 +57,18 @@ type substantiatedIdentity struct {
        CertificateProvider
 }
 
-type identityLoaderFunc func(name string) (CertificateProvider, error)
+var _ Identity = &substantiatedIdentity{}
+
+type stubIdentity struct {
+       *inaccessibleCertificate
+
+       name string
+       cls  PrincipalClass
+}
+
+var _ Identity = &stubIdentity{}
+
+type identityLoaderFunc func(cls PrincipalClass, name string) (CertificateProvider, error)
 
 type identityDriver struct {
        name string
@@ -66,6 +77,8 @@ type identityDriver struct {
 
 var identityDrivers []*identityDriver
 
+var ErrUnsupportedClass = errors.New("this driver does not support loading this identity class")
+
 func RegisterIdentityDriver(name string, load identityLoaderFunc) {
        driver := &identityDriver{
                name: name,
@@ -137,7 +150,7 @@ func identityIsValid(id CertificatePrimitive) bool {
                return false
        }
 
-       if time.Now().Before(cert.NotBefore) || time.Now().After(cert.NotAfter) {
+       if err := certutil.ValidNow(cert); err != nil {
                return false
        }
 
@@ -157,13 +170,6 @@ func identityEquals(a Identity, b Identity) bool {
        return a.Name() == b.Name() && a.Class() == b.Class()
 }
 
-type stubIdentity struct {
-       *inaccessibleCertificate
-
-       name string
-       cls  PrincipalClass
-}
-
 func (id *stubIdentity) Name() string {
        return id.name
 }
@@ -180,7 +186,7 @@ func (id *stubIdentity) IsValid() bool {
        return false
 }
 
-func ParseIdentity(identity string) Identity {
+func ParseIdentity(identity string) LazyIdentity {
        const (
                anonymousIdentityStr = "anonymous"
                userPrefix           = "user."
@@ -188,85 +194,41 @@ func ParseIdentity(identity string) Identity {
        )
 
        if identity == anonymousIdentityStr {
-               logger.V(3).Debugf("ParseIdentity(%q) -> Anonymous()", identity)
-               return Anonymous()
+               logger.V(3).Debugf("ParseIdentity(%q) -> Anonymous", identity)
+               return NewLazyIdentity(AnonymousPrincipal, anonymousIdentityStr)
        } else if strings.HasPrefix(identity, userPrefix) {
-               logger.V(3).Debugf("ParseIdentity(%q) -> NewUserIdentity(%q)", identity, strings.TrimPrefix(identity, userPrefix))
-               return NewUserIdentity(strings.TrimPrefix(identity, userPrefix))
+               logger.V(3).Debugf("ParseIdentity(%q) -> user:%s", identity, strings.TrimPrefix(identity, userPrefix))
+               return NewLazyIdentity(UserPrincipal, strings.TrimPrefix(identity, userPrefix))
        } else if strings.HasPrefix(identity, sslPrefix) {
-               logger.V(3).Debugf("ParseIdentity(%q) -> NewSSLCertificate(%q)", identity, strings.TrimPrefix(identity, sslPrefix))
-               return NewSSLCertificate(strings.TrimPrefix(identity, sslPrefix))
+               logger.V(3).Debugf("ParseIdentity(%q) -> ssl:%s", identity, strings.TrimPrefix(identity, sslPrefix))
+               return NewLazyIdentity(SSLCertificatePrincipal, strings.TrimPrefix(identity, sslPrefix))
        }
 
-       logger.V(3).Debugf("ParseIdentity(%q) -> NewServiceIdentity(%q)", identity, identity)
-       return NewServiceIdentity(identity)
+       logger.V(3).Debugf("ParseIdentity(%q) -> service:%s", identity, identity)
+       return NewLazyIdentity(ServicePrincipal, identity)
 }
 
-func NewServiceIdentity(service string) Identity {
-       for _, driver := range identityDrivers {
-               logger.V(1).Infof("trying driver %s to load service identity %s", driver.name, service)
-               identity, err := driver.load(service)
-
-               if err == nil {
-                       subst := &substantiatedIdentity{
-                               CertificateProvider: identity,
-                       }
-                       logger.V(2).Infof("driver %s reports it loaded identity %s:%s", driver.name, subst.Class().String(), subst.Name())
-
-                       if subst.Name() == service && subst.Class() == ServicePrincipal {
-                               logger.V(1).Noticef("successfully loaded %s(%s) with driver %s", subst.Class().String(), subst.Name(), driver.name)
-                               return subst
-                       } else {
-                               logger.V(2).Warnf(
-                                       "driver %s successfully loaded a certificate, but it doesn't match what "+
-                                               "we expected: class %s (expected)/%s (got); name %s (expected)/%s (got)",
-                                       driver.name, ServicePrincipal.String(), subst.Class().String(),
-                                       service, subst.Name())
-                       }
-               } else {
-                       logger.V(2).Warnf("driver %s failed to load service identity %s: %+v", driver.name, service, err)
-               }
-       }
+func NewRemoteServiceIdentity(service string) Identity {
+       return NewStubIdentity(ServicePrincipal, service)
+}
 
+func NewStubIdentity(cls PrincipalClass, principal string) Identity {
        return &stubIdentity{
                inaccessibleCertificate: &inaccessibleCertificate{},
 
-               name: service,
-               cls:  ServicePrincipal,
+               name: principal,
+               cls:  cls,
        }
 }
 
-func NewUserIdentity(username string) Identity {
-       for _, driver := range identityDrivers {
-               logger.V(1).Infof("trying driver %s to load user identity %s", driver.name, username)
-               identity, err := driver.load(username)
-               if err == nil {
-                       subst := &substantiatedIdentity{
-                               CertificateProvider: identity,
-                       }
-                       logger.V(2).Infof("driver %s reports it loaded identity %s:%s", driver.name, subst.Class().String(), subst.Name())
-
-                       if subst.Name() == username && subst.Class() == UserPrincipal {
-                               logger.V(1).Noticef("successfully loaded %s(%s) with driver %s", subst.Class().String(), subst.Name(), driver.name)
-                               return subst
-                       } else {
-                               logger.V(2).Warnf(
-                                       "driver %s successfully loaded a certificate, but it doesn't match what "+
-                                               "we expected: class %s (expected)/%s (got); name %s (expected)/%s (got)",
-                                       driver.name, ServicePrincipal.String(), subst.Class().String(),
-                                       username, subst.Name())
-                       }
-               } else {
-                       logger.V(2).Warnf("driver %s failed to load service identity %s: %+v", driver.name, username, err)
-               }
-       }
-
-       return &stubIdentity{
-               inaccessibleCertificate: &inaccessibleCertificate{},
+func NewServiceIdentity(service string) Identity {
+       id, _ := NewLazyIdentity(ServicePrincipal, service).Substantiate()
+       return id
+}
 
-               name: username,
-               cls:  UserPrincipal,
-       }
+func NewUserIdentity(username string) Identity {
+       id, _ := NewLazyIdentity(UserPrincipal, username).Substantiate()
+       return id
 }
 
 func NewDefaultUserIdentity() (Identity, error) {
@@ -275,7 +237,12 @@ func NewDefaultUserIdentity() (Identity, error) {
                return nil, err
        }
 
-       return NewUserIdentity(user.Username), nil
+       id := NewLazyIdentity(UserPrincipal, user.Username)
+       if subst, ok := id.Substantiate(); ok {
+               return subst, nil
+       } else {
+               return subst, ErrCertificateInaccessible
+       }
 }
 
 type substantiatedSslCertificate struct {
@@ -329,7 +296,12 @@ func init() {
        logger = log.Default().WithPrefix("mtls")
 }
 
+// SetDefaultIdentity sets the mtls id that is used when -mtls.id is not specified
+// on the command line.
 func SetDefaultIdentity(ident string) {
+       if flag.Parsed() {
+               panic("must call SetDefaultIdentity before flags are parsed")
+       }
        defaultMtlsIdentity = ident
 }
 
@@ -349,7 +321,7 @@ func DefaultIdentity() Identity {
                        log.Default().V(2).Debugf("couldn't load a user identity: err: %+v", err)
                }
 
-               return NewServiceIdentity(defaultDefaultIdentity)
+               return NewLazyIdentity(ServicePrincipal, defaultDefaultIdentity)
        }
 
        return ParseIdentity(defaultMtlsIdentity)
diff --git a/mtls/lazy_identity.go b/mtls/lazy_identity.go
new file mode 100644 (file)
index 0000000..26557fa
--- /dev/null
@@ -0,0 +1,168 @@
+package mtls
+
+import (
+       "context"
+       "crypto"
+       "crypto/tls"
+       "crypto/x509"
+       "sync"
+)
+
+// LazyIdentity represents an identity that is validated on first use, rather than at creation time.
+type LazyIdentity interface {
+       Identity
+
+       // Substantiate attempts to load the certificate and private key corresponding to the Identity.
+       //
+       // It always returns a non-nil Identity; if ok is true, the credentials are available for use.
+       // If ok is false, the returned Identity is invalid. It can be compared to other Identities, but
+       // calls to CertificateProvider functions always return `ErrCertificateInaccessible`.
+       Substantiate() (ident Identity, ok bool)
+}
+
+type lazyIdentity struct {
+       cp CertificateProvider
+
+       mu   sync.Mutex
+       name string
+       cls  PrincipalClass
+}
+
+var _ Identity = &lazyIdentity{}
+
+func NewLazyIdentity(cls PrincipalClass, name string) LazyIdentity {
+       return &lazyIdentity{name: name, cls: cls}
+}
+
+func (id *lazyIdentity) Name() string {
+       return id.name
+}
+
+func (id *lazyIdentity) Class() PrincipalClass {
+       return id.cls
+}
+
+func (id *lazyIdentity) Equals(other Identity) bool {
+       return identityEquals(id, other)
+}
+
+func (id *lazyIdentity) IsValid() bool {
+       id.tryLoad()
+
+       if id.cp != nil {
+               return identityIsValid(id.cp)
+       }
+
+       return false
+}
+
+func (id *lazyIdentity) tryLoad() {
+       id.mu.Lock()
+       defer id.mu.Unlock()
+
+       if id.cp != nil {
+               return
+       }
+
+       if id.cls == AnonymousPrincipal {
+               id.cp = Anonymous()
+               return
+       }
+       if id.cls == SSLCertificatePrincipal {
+               if cert := NewSSLCertificate(id.name); cert.IsValid() {
+                       id.cp = cert
+               }
+               return
+       }
+       for _, driver := range identityDrivers {
+               logger.V(1).Infof("trying driver %s to load %s identity %s", driver.name, id.cls, id.name)
+               if cert, err := driver.load(id.cls, id.name); err == nil {
+                       subst := &substantiatedIdentity{cert}
+                       if subst.Name() == id.name && subst.Class() == id.cls {
+                               logger.V(2).Infof("driver %s reports it loaded identity %s:%s", driver.name, id.cls.String(), id.name)
+                               id.cp = cert
+                               return
+                       } else {
+                               logger.V(2).Warnf(
+                                       "driver %s successfully loaded a certificate, but it doesn't match what "+
+                                               "we expected: class %s (expected)/%s (got); name %s (expected)/%s (got)",
+                                       driver.name, id.cls.String(), subst.Class().String(),
+                                       id.name, subst.Name())
+                       }
+               } else {
+                       logger.V(2).Warnf("driver %s failed to load %s identity %s: %+v", driver.name, id.cls.String(), id.name, err)
+               }
+       }
+}
+
+func (id *lazyIdentity) LeafCertificate() (*x509.Certificate, error) {
+       id.tryLoad()
+       if id.cp == nil {
+               return nil, ErrCertificateInaccessible
+       }
+       return id.cp.LeafCertificate()
+}
+
+func (id *lazyIdentity) IntermediateCertificates() ([]*x509.Certificate, error) {
+       id.tryLoad()
+       if id.cp == nil {
+               return nil, ErrCertificateInaccessible
+       }
+       return id.cp.IntermediateCertificates()
+}
+
+func (id *lazyIdentity) RootCertificate() (*x509.Certificate, error) {
+       id.tryLoad()
+       if id.cp == nil {
+               return nil, ErrCertificateInaccessible
+       }
+       return id.cp.RootCertificate()
+}
+
+func (id *lazyIdentity) PrivateKey() (crypto.PrivateKey, error) {
+       id.tryLoad()
+       if id.cp == nil {
+               return nil, ErrCertificateInaccessible
+       }
+       return id.cp.PrivateKey()
+}
+
+func (id *lazyIdentity) NewTlsCertificate() (*tls.Certificate, error) {
+       id.tryLoad()
+       if id.cp == nil {
+               return nil, ErrCertificateInaccessible
+       }
+       return id.cp.NewTlsCertificate()
+}
+
+func (id *lazyIdentity) NewDialContextFunc() DialContextFunc {
+       id.tryLoad()
+       if id.cp == nil {
+               return (&inaccessibleCertificate{}).NewDialContextFunc()
+       }
+       return id.cp.NewDialContextFunc()
+}
+
+func (id *lazyIdentity) TlsConfig(ctx context.Context) (*tls.Config, error) {
+       id.tryLoad()
+       if id.cp == nil {
+               return nil, ErrCertificateInaccessible
+       }
+       return id.cp.TlsConfig(ctx)
+}
+
+func (id *lazyIdentity) Substantiate() (Identity, bool) {
+       id.tryLoad()
+
+       if id.cp != nil {
+               return &substantiatedIdentity{
+                       CertificateProvider: id.cp,
+               }, true
+       }
+
+       return &stubIdentity{
+               inaccessibleCertificate: &inaccessibleCertificate{},
+               name:                    id.name,
+               cls:                     id.cls,
+       }, false
+}
index 4767051e9182e9c372c76c1310917ae59931532f..93123700b02a0cd7d3b4a3c69df931190a0e3876 100644 (file)
@@ -163,7 +163,7 @@ func newFileBackedCertificateFromBaseDir(mtlsRootPath string, serviceIdentity st
        }, nil
 }
 
-func LoadUserIdentityFromFilesystem() (*FileBackedCertificate, error) {
+func LoadUserIdentityFromFilesystem(username string) (*FileBackedCertificate, error) {
        fullChainPath, ok := os.LookupEnv("STEP_PERSONAL_CERTIFICATE")
        if !ok {
                return nil, fmt.Errorf("failed to get user certificate path from env STEP_PERSONAL_CERTIFICATE")
@@ -182,12 +182,35 @@ func LoadUserIdentityFromFilesystem() (*FileBackedCertificate, error) {
                }
        }
 
-       return &FileBackedCertificate{
+       cert := &FileBackedCertificate{
                LeafPath:          fullChainPath,
                IntermediatesPath: fullChainPath,
                PrivateKeyPath:    keyPath,
                RootPath:          rootPath,
-       }, nil
+       }
+
+       leaf, err := cert.LeafCertificate()
+       if err != nil {
+               return nil, fmt.Errorf("error reading leaf certificate: %v", err)
+       }
+
+       spiffe := certutil.SpiffeUrlFromCertificate(leaf)
+       if spiffe == nil {
+               return nil, fmt.Errorf("error getting spiffe URL from loaded leaf certificate %s", fullChainPath)
+       }
+
+       id, err := ParseRemoteIdentity(spiffe.String())
+       if err != nil {
+               return nil, fmt.Errorf("failure parsing spiffe URL %q from loaded leaf certificate %s: %v",
+                       spiffe, fullChainPath, err)
+       }
+
+       if id.Class != User || id.Principal != username {
+               return nil, fmt.Errorf("successfully loaded certificate at %s, but certificate contains the wrong credential: %s",
+                       id.ToSpiffe())
+       }
+
+       return cert, nil
 }
 
 func LoadSSLCertificateFromFilesystem(certName string) (*FileBackedCertificate, error) {
@@ -578,14 +601,23 @@ func init() {
        flag.StringVar(&sslCertsBaseDir, "tls.certs-dir", sslCertsBaseDir, "directory to look under for public-site SSL certificates (NOT mTLS certs)")
        flag.Func("mtls.certs-dir", "additional directory to search for mTLS certificates", appendMtlsCertificateDir)
 
-       RegisterIdentityDriver("file_service_global", func(serviceName string) (CertificateProvider, error) {
+       RegisterIdentityDriver("file_service_global", func(cls PrincipalClass, serviceName string) (CertificateProvider, error) {
+               if cls != ServicePrincipal {
+                       return nil, ErrUnsupportedClass
+               }
                return LoadServiceIdentityFromFilesystem(serviceName)
        })
-       RegisterIdentityDriver("file_service_csi_spiffe", func(serviceName string) (CertificateProvider, error) {
+       RegisterIdentityDriver("file_service_csi_spiffe", func(cls PrincipalClass, serviceName string) (CertificateProvider, error) {
+               if cls != ServicePrincipal {
+                       return nil, ErrUnsupportedClass
+               }
                return LoadServiceIdentityFromKubernetesCSIDriverSPIFFE(serviceName)
        })
-       RegisterIdentityDriver("file_user_home", func(_ string) (CertificateProvider, error) {
-               return LoadUserIdentityFromFilesystem()
+       RegisterIdentityDriver("file_user_home", func(cls PrincipalClass, username string) (CertificateProvider, error) {
+               if cls != UserPrincipal {
+                       return nil, ErrUnsupportedClass
+               }
+               return LoadUserIdentityFromFilesystem(username)
        })
        RegisterRootDriver("file_etc_mtls", func() (RootsPrimitive, error) {
                return defaultFileBackedRoots, nil
index fcc0129596e80147f76115fc990ceac3a0bd7ed6..bf9c0ec9a5c5b81a7ea97ccb6c73e6124839de9a 100644 (file)
@@ -66,7 +66,7 @@ func (kso *kcSignerOpts) HashFunc() crypto.Hash {
 
 var kcLogger log.Logger
 
-func NewCertificateFromMacKeychain(principal string) (CertificateProvider, error) {
+func NewCertificateFromMacKeychain(cls PrincipalClass, principal string) (CertificateProvider, error) {
        root, err := getMtlsRootFromMacKeychain()
        if err != nil {
                return nil, err
@@ -82,7 +82,7 @@ func NewCertificateFromMacKeychain(principal string) (CertificateProvider, error
                kcLogger.V(2).Debugf("loaded intermediate cert from keychain: %s", c.Subject.String())
        }
 
-       leaves, err := getLeafCertificatesFromKeychainMatchingPrincipal(ServicePrincipal, principal)
+       leaves, err := getLeafCertificatesFromKeychainMatchingPrincipal(cls, principal)
        if err != nil {
                return nil, err
        }
index 5f74a61ddf899f179426a8dc204116c59d98291c..48ad0fad6bd62011be815679b26a7acb8aa364ae 100644 (file)
@@ -10,6 +10,7 @@ import (
        "fmt"
        "path"
 
+       "go.fuhry.dev/runtime/constants"
        "go.fuhry.dev/runtime/mtls/certutil"
        "go.fuhry.dev/runtime/utils/log"
 )
@@ -112,7 +113,10 @@ func (c *TPMBackedCertificate) NewDialContextFunc() DialContextFunc {
 }
 
 func init() {
-       RegisterIdentityDriver("tpm2-pkcs11", func(_ string) (CertificateProvider, error) {
+       RegisterIdentityDriver("tpm2-pkcs11", func(cls PrincipalClass, serviceName string) (CertificateProvider, error) {
+               if cls != ServicePrincipal || serviceName != constants.DeviceTrustPrincipal {
+                       return nil, ErrUnsupportedClass
+               }
                return NewTPMBackedCertificate()
        })
 }
index 96845489dbf6026a52c8e41a6758c527f8cbcd86..4640f62f7b768879f2836dd5009db161d735d833 100644 (file)
@@ -4,6 +4,7 @@ import (
        "crypto/tls"
        "crypto/x509"
        "fmt"
+       "net/url"
        "regexp"
        "strings"
 
@@ -276,3 +277,26 @@ func ParseRemoteIdentity(name string) (*RemoteIdentity, error) {
 
        return nil, fmt.Errorf("cannot understand CN/SAN %q as an mTLS identity", name)
 }
+
+func (i *RemoteIdentity) ToSpiffe() *url.URL {
+       return &url.URL{
+               Scheme: "spiffe",
+               Host:   i.Domain,
+               Path:   fmt.Sprintf("%s/%s", i.Class.String(), i.Principal),
+       }
+}
+
+func (i *RemoteIdentity) ToDnsName() string {
+       switch i.Class {
+       case User:
+               return fmt.Sprintf("%s.%s.%s",
+                       i.Principal,
+                       i.Class.String(),
+                       i.Domain)
+       default:
+               return fmt.Sprintf("%s.%s.%s.mtls.internal",
+                       i.Principal,
+                       i.Class.String(),
+                       i.Domain)
+       }
+}