]> go.fuhry.dev Git - runtime.git/commitdiff
[grpc] adopt Option pattern, add DNS SRV support
authorDan Fuhry <dan@fuhry.com>
Wed, 19 Nov 2025 14:30:24 +0000 (09:30 -0500)
committerDan Fuhry <dan@fuhry.com>
Wed, 19 Nov 2025 14:33:18 +0000 (09:33 -0500)
- Allow gRPC server and client factories to accept variadic Options
- Add `ClientOption`: `WithDNSSRV`, which uses DNS SRV queries instead of SD
- Add `ServerOption`: `WithTransport` (overrides `-grpc.transport` flag)
- Force ephs server to use QUIC transport as QUIC is hardcoded into the client library

cmd/ephs_server/main.go
grpc/imports.go
grpc/internal/client/BUILD.bazel
grpc/internal/client/client.go
grpc/internal/client/dns_srv.go [new file with mode: 0644]
grpc/internal/server/server.go
net/dns/dns_cache.go

index 2995fa3728dd415f765cf43896224779745f2c89..288696e1bd76801f51d78a9c8156a0be455c4c2e 100644 (file)
@@ -23,7 +23,7 @@ func main() {
        flag.Parse()
 
        serverIdentity := mtls.DefaultIdentity()
-       s, err := grpc.NewGrpcServer(serverIdentity)
+       s, err := grpc.NewGrpcServer(serverIdentity, grpc.WithTransport(&grpc.QUICConnectionFactory{}))
        if err != nil {
                panic(err)
        }
index 5fdef20f5aa568259239b777bcd617e55b9e8e9c..59a11b55f9b23c2f63e9b1b8cf26c143957270d5 100644 (file)
@@ -20,6 +20,7 @@ type ClientConn = client.ClientConn
 
 // type aliases: server
 type Server = server.Server
+type ServerOption = server.ServerOption
 
 // function aliases: common
 var NewDefaultConnectionFactory = common.NewDefaultConnectionFactory
@@ -29,6 +30,7 @@ var RegisterTransport = common.RegisterTransport
 var WithConnectionFactory = client.WithConnectionFactory
 var WithAddressProvider = client.WithAddressProvider
 var WithStaticAddress = client.WithStaticAddress
+var WithDNSSRV = client.WithDNSSRV
 
 var NewGrpcClient = client.NewGrpcClient
 
@@ -39,3 +41,5 @@ var NewGrpcServerWithPort = server.NewGrpcServerWithPort
 var PeerCertificate = server.PeerCertificate
 var PeerIdentity = server.PeerIdentity
 var NewHealthCheckServicer = server.NewHealthCheckServicer
+var SessionFromContext = server.SessionFromContext
+var WithTransport = server.WithTransport
index 42d11cee7436a5c4a9c1eb9c930bfc091e963740..7e89d3a1a4d8922e86a96af7c3e4e11d0b7a23de 100644 (file)
@@ -9,12 +9,18 @@ package(
 
 go_library(
     name = "client",
-    srcs = ["client.go"],
+    srcs = [
+        "client.go",
+        "dns_srv.go",
+    ],
     importpath = "go.fuhry.dev/runtime/grpc/internal/client",
     deps = [
+        "//constants",
         "//grpc/internal/common",
         "//mtls",
+        "//net/dns",
         "//sd",
+        "//utils/context",
         "@org_golang_google_grpc//:grpc",
     ],
 )
index bc5436960730346b2c8215f6eac90a337d610bd5..60d0edf9f12d4bdb071842636e5b026188d4890f 100644 (file)
@@ -95,6 +95,19 @@ func WithStaticAddress(addresses ...*net.TCPAddr) ClientOption {
        }
 }
 
+func WithDNSSRV() ClientOption {
+       return &clientOption{
+               f: func(c *client) error {
+                       ap := &dnsSrvAddressProvider{
+                               serverId: c.serverId,
+                       }
+
+                       c.watcher = ap
+                       return nil
+               },
+       }
+}
+
 func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity, opts ...ClientOption) (Client, error) {
        cl := &client{
                ctx:      ctx,
diff --git a/grpc/internal/client/dns_srv.go b/grpc/internal/client/dns_srv.go
new file mode 100644 (file)
index 0000000..00241e9
--- /dev/null
@@ -0,0 +1,36 @@
+package client
+
+import (
+       "go.fuhry.dev/runtime/constants"
+       "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/net/dns"
+       "go.fuhry.dev/runtime/sd"
+       "go.fuhry.dev/runtime/utils/context"
+)
+
+type dnsSrvAddressProvider struct {
+       serverId mtls.Identity
+}
+
+var _ AddressProvider = &dnsSrvAddressProvider{}
+
+func (d *dnsSrvAddressProvider) GetAddrs(context.Context) ([]sd.ServiceAddress, error) {
+       results, err := dns.ResolveSRV(d.serverId.Name(), "grpc", constants.SDDomain)
+       if err != nil {
+               return nil, err
+       }
+
+       var sdAddrs []sd.ServiceAddress
+       for _, r := range results {
+               sdAddrs = append(sdAddrs, sd.ServiceAddress{
+                       Hostname: r.Hostname,
+                       IP4:      r.Ip4.String(),
+                       IP6:      r.Ip6.String(),
+                       Port:     r.Port,
+                       Protocol: sd.ProtocolGRPC,
+                       Service:  d.serverId.Name(),
+               })
+       }
+
+       return sdAddrs, nil
+}
index d5a3739f2a62a76ebb882f9e5690eed28affa947..93881d34a5670794d73f7deece3159f13cf98ec3 100644 (file)
@@ -3,6 +3,8 @@ package server
 import (
        "context"
        "crypto/tls"
+       "crypto/x509"
+       "errors"
        "flag"
        "fmt"
        "math/rand"
@@ -38,21 +40,42 @@ type Server struct {
        hc         HealthCheckServicer
 }
 
+type ServerOption interface {
+       apply(*Server) error
+}
+
+type serverOption struct {
+       callback func(*Server) error
+}
+
+func (o *serverOption) apply(s *Server) error {
+       return o.callback(s)
+}
+
+func WithTransport(cf common.ConnectionFactory) ServerOption {
+       return &serverOption{
+               callback: func(s *Server) error {
+                       s.connFac = cf
+                       return nil
+               },
+       }
+}
+
 var defaultPort *uint
 
 func RandomPort() uint {
        return uint(1025 + (uint(rand.Int()) % (65535 - 1025)))
 }
 
-func NewGrpcServer(id mtls.Identity) (*Server, error) {
+func NewGrpcServer(id mtls.Identity, opts ...ServerOption) (*Server, error) {
        if !flag.Parsed() {
                panic("cannot start grpc services before flags are parsed")
        }
 
-       return NewGrpcServerWithPort(id, uint16(*defaultPort))
+       return NewGrpcServerWithPort(id, uint16(*defaultPort), opts...)
 }
 
-func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) {
+func NewGrpcServerWithPort(id mtls.Identity, port uint16, opts ...ServerOption) (*Server, error) {
        etcdc, err := sd.NewDefaultEtcdClient()
        if err != nil {
                panic(err)
@@ -87,10 +110,19 @@ func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) {
                verifier:  cv,
                log:       log.WithPrefix(fmt.Sprintf("grpcServer:%s", id.Name())),
                sessions:  sessionsLru,
-               connFac:   common.NewDefaultConnectionFactory(),
                hc:        NewHealthCheckServicer(),
        }
 
+       for _, opt := range opts {
+               if err := opt.apply(server); err != nil {
+                       return nil, err
+               }
+       }
+
+       if server.connFac == nil {
+               server.connFac = common.NewDefaultConnectionFactory()
+       }
+
        return server, nil
 }
 
index 6f63f2a6f14bce7b0add453854f1196b34859c4f..bb024af5fd160c4bd47d37ef4ce77c8502b6c4d4 100644 (file)
@@ -16,6 +16,15 @@ import (
        "go.fuhry.dev/runtime/utils/log"
 )
 
+type SRVResult struct {
+       Priority uint16
+       Weight   uint16
+       Port     uint16
+       Hostname string
+       Ip4      net.IP
+       Ip6      net.IP
+}
+
 var dnsCache *lru.Cache
 var dnsCacheInit sync.Once
 
@@ -105,9 +114,120 @@ func ResolveDualStack(name string) (string, string, error) {
        return ip4, ip6, nil
 }
 
-func doDualStackQuery(hostname string) (*dns.Msg, error) {
-       var msg *dns.Msg
+func ResolveSRV(service, protocol, domain string) ([]SRVResult, error) {
+       cc, err := dnsClientConfig()
+       if err != nil {
+               return nil, err
+       }
+
+       name := fmt.Sprintf("_%s._%s.%s.", service, protocol, domain)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       result := make(chan *dns.Msg, 0)
+       var wg sync.WaitGroup
+
+       query := makeQuery(name, dns.TypeSRV)
+       client := &dns.Client{}
+
+       for _, server := range cc.Servers {
+               wg.Add(1)
+               go (func() {
+                       defer wg.Done()
+
+                       log.WithPrefix("ResolveSRV").V(3).Debugf("attempting query against server %q: %v", server, query)
+                       answer, _, err := client.Exchange(query, net.JoinHostPort(server, "53"))
+                       if err == nil && answer != nil {
+                               log.WithPrefix("ResolveSRV").V(3).Debugf("got answer %+v", answer)
+                               select {
+                               case result <- answer:
+                               case <-ctx.Done():
+                               }
+                       } else {
+                               log.WithPrefix("ResolveSRV").V(3).Debugf("error querying server %q: %v", server, err)
+                       }
+               })()
+       }
+
+       go (func() {
+               wg.Wait()
+               close(result)
+               cancel()
+       })()
+
+       var answer *dns.Msg
+       select {
+       case answer = <-result:
+       case <-ctx.Done():
+               return nil, ctx.Err()
+       }
+
+       if answer == nil {
+               return nil, fmt.Errorf("no answer received for query SRV %q", name)
+       }
+
+       var mu sync.Mutex
+       var results []SRVResult
+       var wg2 = sync.WaitGroup{}
+
+       for _, rr := range answer.Answer {
+               if srvRecord, ok := rr.(*dns.SRV); ok {
+                       wg2.Add(1)
+                       go (func() {
+                               defer wg2.Done()
+
+                               if ip4Str, ip6Str, err := ResolveDualStack(srvRecord.Target); err == nil {
+                                       ip4 := net.ParseIP(ip4Str)
+                                       ip6 := net.ParseIP(ip6Str)
 
+                                       answer := SRVResult{
+                                               Priority: srvRecord.Priority,
+                                               Weight:   srvRecord.Weight,
+                                               Port:     srvRecord.Port,
+                                               Hostname: srvRecord.Target,
+                                               Ip4:      ip4,
+                                               Ip6:      ip6,
+                                       }
+
+                                       mu.Lock()
+                                       defer mu.Unlock()
+                                       results = append(results, answer)
+                               }
+                       })()
+               }
+       }
+
+       wg2.Wait()
+
+       if len(results) < 1 {
+               return nil, fmt.Errorf("no results returned for SRV %q", name)
+       }
+
+       return results, nil
+}
+
+func makeQuery(name string, qtype uint16) *dns.Msg {
+       return &dns.Msg{
+               MsgHdr: dns.MsgHdr{
+                       Id:                dns.Id(),
+                       Opcode:            dns.OpcodeQuery,
+                       Response:          false,
+                       RecursionDesired:  true,
+                       AuthenticatedData: true,
+               },
+               Question: []dns.Question{
+                       dns.Question{
+                               Name:   name,
+                               Qtype:  qtype,
+                               Qclass: dns.ClassINET,
+                       },
+               },
+               Extra: []dns.RR{
+                       newEDNSCookie(),
+               },
+       }
+}
+
+func dnsClientConfig() (*dns.ClientConfig, error) {
        // On Linux systems, the stub resolv.conf points at systemd-resolved which uses the hosts
        // file. We don't want this - we need the system's addresses as described by the network's
        // DNS server.
@@ -115,7 +235,13 @@ func doDualStackQuery(hostname string) (*dns.Msg, error) {
        if err := fsutil.FileExistsAndIsReadable(resolvConfPath); err != nil {
                resolvConfPath = "/etc/resolv.conf"
        }
-       cc, err := dns.ClientConfigFromFile(resolvConfPath)
+       return dns.ClientConfigFromFile(resolvConfPath)
+}
+
+func doDualStackQuery(hostname string) (*dns.Msg, error) {
+       var msg *dns.Msg
+
+       cc, err := dnsClientConfig()
        if err != nil {
                return nil, err
        }
@@ -123,25 +249,7 @@ func doDualStackQuery(hostname string) (*dns.Msg, error) {
        log.V(3).Debugf("cache miss, attempting dualstack dns query for hostname: %q", hostname)
        log.V(3).Debugf("will resolve using DNS servers: %+v", cc.Servers)
        for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
-               query := &dns.Msg{
-                       MsgHdr: dns.MsgHdr{
-                               Id:                dns.Id(),
-                               Opcode:            dns.OpcodeQuery,
-                               Response:          false,
-                               RecursionDesired:  true,
-                               AuthenticatedData: true,
-                       },
-                       Question: make([]dns.Question, 1),
-                       Extra: []dns.RR{
-                               newEDNSCookie(),
-                       },
-               }
-
-               query.Question[0] = dns.Question{
-                       Name:   hostname,
-                       Qtype:  qtype,
-                       Qclass: dns.ClassINET,
-               }
+               query := makeQuery(hostname, qtype)
 
                mu := &sync.Mutex{}
                done := false