From a153be2f17ed2f27ca7e01be7bac59b4bebc5edd Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Wed, 19 Nov 2025 09:34:52 -0500 Subject: [PATCH] [mtls] make some types and functions public Make public the necessary functions and types to allow other packages to register identity and roots providers. - Types `CertificatePrimitive` and `RootsPrimitive` - Driver registration functions: `RegisterIdentityDriver`, `RegisterRootDriver` - `newDialContextFunc` -> `MakeDialContextFunc` - `newTlsCertificate` -> `MakeTlsCertificate` --- mtls/identity.go | 4 ++-- mtls/provider_anonymous.go | 4 ++-- mtls/provider_file.go | 22 +++++++++++----------- mtls/provider_interface.go | 12 ++++++------ mtls/provider_keychain_macos.go | 16 ++++++++-------- mtls/provider_shared.go | 4 ++-- mtls/provider_tpm2_pkcs11.go | 10 +++++----- mtls/verify_roots.go | 4 ++-- 8 files changed, 38 insertions(+), 38 deletions(-) diff --git a/mtls/identity.go b/mtls/identity.go index 2f527d4..4631a2b 100644 --- a/mtls/identity.go +++ b/mtls/identity.go @@ -66,7 +66,7 @@ type identityDriver struct { var identityDrivers []*identityDriver -func registerIdentityDriver(name string, load identityLoaderFunc) { +func RegisterIdentityDriver(name string, load identityLoaderFunc) { driver := &identityDriver{ name: name, load: load, @@ -131,7 +131,7 @@ func (id *substantiatedIdentity) IsValid() bool { return identityIsValid(id.CertificateProvider) } -func identityIsValid(id certificatePrimitive) bool { +func identityIsValid(id CertificatePrimitive) bool { cert, err := id.LeafCertificate() if err != nil { return false diff --git a/mtls/provider_anonymous.go b/mtls/provider_anonymous.go index 0a0ea57..20eea61 100644 --- a/mtls/provider_anonymous.go +++ b/mtls/provider_anonymous.go @@ -44,10 +44,10 @@ func (a *anonymousIdentity) PrivateKey() (crypto.PrivateKey, error) { } func (a *anonymousIdentity) NewDialContextFunc() DialContextFunc { - return newDialContextFunc(a) + return MakeDialContextFunc(a) } -func (a *anonymousIdentity) newTlsCertificate() (*tls.Certificate, error) { +func (a *anonymousIdentity) NewTlsCertificate() (*tls.Certificate, error) { return nil, nil } diff --git a/mtls/provider_file.go b/mtls/provider_file.go index 3b2952d..4767051 100644 --- a/mtls/provider_file.go +++ b/mtls/provider_file.go @@ -471,19 +471,19 @@ func (c *FileBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, err } func (c *FileBackedCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return c.newTlsCertificate() + return c.NewTlsCertificate() } func (c *FileBackedCertificate) GetClientCertificate(reqInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return c.newTlsCertificate() + return c.NewTlsCertificate() } -func (c *FileBackedCertificate) newTlsCertificate() (*tls.Certificate, error) { - return newTlsCertificate(c) +func (c *FileBackedCertificate) NewTlsCertificate() (*tls.Certificate, error) { + return MakeTlsCertificate(c) } func (c *FileBackedCertificate) NewDialContextFunc() DialContextFunc { - return newDialContextFunc(c) + return MakeDialContextFunc(c) } func (c *FileBackedCertificate) notifyEvent(filePath string, op fsnotify.Op) { @@ -578,22 +578,22 @@ func init() { flag.StringVar(&sslCertsBaseDir, "tls.certs-dir", sslCertsBaseDir, "directory to look under for public-site SSL certificates (NOT mTLS certs)") flag.Func("mtls.certs-dir", "additional directory to search for mTLS certificates", appendMtlsCertificateDir) - registerIdentityDriver("file_service_global", func(serviceName string) (CertificateProvider, error) { + RegisterIdentityDriver("file_service_global", func(serviceName string) (CertificateProvider, error) { return LoadServiceIdentityFromFilesystem(serviceName) }) - registerIdentityDriver("file_service_csi_spiffe", func(serviceName string) (CertificateProvider, error) { + RegisterIdentityDriver("file_service_csi_spiffe", func(serviceName string) (CertificateProvider, error) { return LoadServiceIdentityFromKubernetesCSIDriverSPIFFE(serviceName) }) - registerIdentityDriver("file_user_home", func(_ string) (CertificateProvider, error) { + RegisterIdentityDriver("file_user_home", func(_ string) (CertificateProvider, error) { return LoadUserIdentityFromFilesystem() }) - registerRootDriver("file_etc_mtls", func() (rootsPrimitive, error) { + RegisterRootDriver("file_etc_mtls", func() (RootsPrimitive, error) { return defaultFileBackedRoots, nil }) - registerRootDriver("file_csi_spiffe", func() (rootsPrimitive, error) { + RegisterRootDriver("file_csi_spiffe", func() (RootsPrimitive, error) { return csiSpiffeRoots, nil }) - registerRootDriver("file_csi_spiffe_altname", func() (rootsPrimitive, error) { + RegisterRootDriver("file_csi_spiffe_altname", func() (RootsPrimitive, error) { return csiSpiffeRootsAltName, nil }) } diff --git a/mtls/provider_interface.go b/mtls/provider_interface.go index f42f122..031909f 100644 --- a/mtls/provider_interface.go +++ b/mtls/provider_interface.go @@ -12,25 +12,25 @@ import ( type DialContextFunc func(context.Context, string, string) (net.Conn, error) type CertificateProvider interface { - certificatePrimitive + CertificatePrimitive TlsConfig(context.Context) (*tls.Config, error) NewDialContextFunc() DialContextFunc } -type rootsPrimitive interface { +type RootsPrimitive interface { RootCertificates() ([]*x509.Certificate, error) IntermediateCertificates() ([]*x509.Certificate, error) } -type certificatePrimitive interface { +type CertificatePrimitive interface { RootCertificate() (*x509.Certificate, error) IntermediateCertificates() ([]*x509.Certificate, error) LeafCertificate() (*x509.Certificate, error) PrivateKey() (crypto.PrivateKey, error) - newTlsCertificate() (*tls.Certificate, error) + NewTlsCertificate() (*tls.Certificate, error) } type inaccessibleCertificate struct{} @@ -53,7 +53,7 @@ func (c *inaccessibleCertificate) PrivateKey() (crypto.PrivateKey, error) { return nil, ErrCertificateInaccessible } -func (c *inaccessibleCertificate) newTlsCertificate() (*tls.Certificate, error) { +func (c *inaccessibleCertificate) NewTlsCertificate() (*tls.Certificate, error) { return nil, ErrCertificateInaccessible } @@ -62,5 +62,5 @@ func (c *inaccessibleCertificate) TlsConfig(ctx context.Context) (*tls.Config, e } func (c *inaccessibleCertificate) NewDialContextFunc() DialContextFunc { - return newDialContextFunc(c) + return MakeDialContextFunc(c) } diff --git a/mtls/provider_keychain_macos.go b/mtls/provider_keychain_macos.go index 0d174e1..fcc0129 100644 --- a/mtls/provider_keychain_macos.go +++ b/mtls/provider_keychain_macos.go @@ -27,7 +27,7 @@ import ( ) type macosKeychainCertificate struct { - certificatePrimitive + CertificatePrimitive ints []*x509.Certificate root *x509.Certificate @@ -118,12 +118,12 @@ func (c *macosKeychainCertificate) PrivateKey() (crypto.PrivateKey, error) { return c.pkey, nil } -func (c *macosKeychainCertificate) newTlsCertificate() (*tls.Certificate, error) { - return newTlsCertificate(c) +func (c *macosKeychainCertificate) NewTlsCertificate() (*tls.Certificate, error) { + return MakeTlsCertificate(c) } func (c *macosKeychainCertificate) NewDialContextFunc() DialContextFunc { - return newDialContextFunc(c) + return MakeDialContextFunc(c) } func (c *macosKeychainCertificate) TlsConfig(ctx context.Context) (*tls.Config, error) { @@ -136,11 +136,11 @@ func (c *macosKeychainCertificate) TlsConfig(ctx context.Context) (*tls.Config, } func (c *macosKeychainCertificate) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return c.newTlsCertificate() + return c.NewTlsCertificate() } func (c *macosKeychainCertificate) GetClientCertificate(reqInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return c.newTlsCertificate() + return c.NewTlsCertificate() } func getMtlsIntermediatesFromMacKeychain() ([]*x509.Certificate, error) { @@ -448,9 +448,9 @@ func (kcr *macosKeychainRoots) IntermediateCertificates() ([]*x509.Certificate, func init() { kcLogger = log.WithPrefix("mtls.macOSKeychain") - registerIdentityDriver("macos_keychain", NewCertificateFromMacKeychain) + RegisterIdentityDriver("macos_keychain", NewCertificateFromMacKeychain) - registerRootDriver("macos_keychain", func() (rootsPrimitive, error) { + RegisterRootDriver("macos_keychain", func() (RootsPrimitive, error) { return &macosKeychainRoots{}, nil }) } diff --git a/mtls/provider_shared.go b/mtls/provider_shared.go index 183e017..79da91c 100644 --- a/mtls/provider_shared.go +++ b/mtls/provider_shared.go @@ -6,7 +6,7 @@ import ( "net" ) -func newTlsCertificate(id certificatePrimitive) (*tls.Certificate, error) { +func MakeTlsCertificate(id CertificatePrimitive) (*tls.Certificate, error) { leafCertificate, err := id.LeafCertificate() if err != nil { return nil, err @@ -32,7 +32,7 @@ func newTlsCertificate(id certificatePrimitive) (*tls.Certificate, error) { }, nil } -func newDialContextFunc(id CertificateProvider) DialContextFunc { +func MakeDialContextFunc(id CertificateProvider) DialContextFunc { dcf := func(ctx context.Context, network, addr string) (net.Conn, error) { c, err := id.TlsConfig(ctx) if err != nil { diff --git a/mtls/provider_tpm2_pkcs11.go b/mtls/provider_tpm2_pkcs11.go index aafd50b..5f74a61 100644 --- a/mtls/provider_tpm2_pkcs11.go +++ b/mtls/provider_tpm2_pkcs11.go @@ -17,7 +17,7 @@ import ( var tpmLogger = log.WithPrefix("mtls.provider_tpm2_pkcs11") type TPMBackedCertificate struct { - certificatePrimitive + CertificatePrimitive p11 *p11 } @@ -91,11 +91,11 @@ func (c *TPMBackedCertificate) RootCertificate() (*x509.Certificate, error) { } func (c *TPMBackedCertificate) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return newTlsCertificate(c) + return MakeTlsCertificate(c) } func (c *TPMBackedCertificate) getClientCertificate(reqInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { - return newTlsCertificate(c) + return MakeTlsCertificate(c) } func (c *TPMBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, error) { @@ -108,11 +108,11 @@ func (c *TPMBackedCertificate) TlsConfig(ctx context.Context) (*tls.Config, erro } func (c *TPMBackedCertificate) NewDialContextFunc() DialContextFunc { - return newDialContextFunc(c) + return MakeDialContextFunc(c) } func init() { - registerIdentityDriver("tpm2-pkcs11", func(_ string) (CertificateProvider, error) { + RegisterIdentityDriver("tpm2-pkcs11", func(_ string) (CertificateProvider, error) { return NewTPMBackedCertificate() }) } diff --git a/mtls/verify_roots.go b/mtls/verify_roots.go index 37f4e28..b52159f 100644 --- a/mtls/verify_roots.go +++ b/mtls/verify_roots.go @@ -14,10 +14,10 @@ type tlsVerifyPeerCertificatesFunc = func(rawCerts [][]byte, verifiedChains [][] type rootDriver struct { name string - load func() (rootsPrimitive, error) + load func() (RootsPrimitive, error) } -func registerRootDriver(name string, load func() (rootsPrimitive, error)) { +func RegisterRootDriver(name string, load func() (RootsPrimitive, error)) { if rootsDrivers == nil { rootsDrivers = make([]*rootDriver, 0) } -- 2.50.1