From: Dan Fuhry Date: Wed, 19 Nov 2025 14:30:24 +0000 (-0500) Subject: [grpc] adopt Option pattern, add DNS SRV support X-Git-Url: https://go.fuhry.dev/?a=commitdiff_plain;h=1be16cb21752653a4a94e734109d43e7d04e7e5e;p=runtime.git [grpc] adopt Option pattern, add DNS SRV support - Allow gRPC server and client factories to accept variadic Options - Add `ClientOption`: `WithDNSSRV`, which uses DNS SRV queries instead of SD - Add `ServerOption`: `WithTransport` (overrides `-grpc.transport` flag) - Force ephs server to use QUIC transport as QUIC is hardcoded into the client library --- diff --git a/cmd/ephs_server/main.go b/cmd/ephs_server/main.go index 2995fa3..288696e 100644 --- a/cmd/ephs_server/main.go +++ b/cmd/ephs_server/main.go @@ -23,7 +23,7 @@ func main() { flag.Parse() serverIdentity := mtls.DefaultIdentity() - s, err := grpc.NewGrpcServer(serverIdentity) + s, err := grpc.NewGrpcServer(serverIdentity, grpc.WithTransport(&grpc.QUICConnectionFactory{})) if err != nil { panic(err) } diff --git a/grpc/imports.go b/grpc/imports.go index 5fdef20..59a11b5 100644 --- a/grpc/imports.go +++ b/grpc/imports.go @@ -20,6 +20,7 @@ type ClientConn = client.ClientConn // type aliases: server type Server = server.Server +type ServerOption = server.ServerOption // function aliases: common var NewDefaultConnectionFactory = common.NewDefaultConnectionFactory @@ -29,6 +30,7 @@ var RegisterTransport = common.RegisterTransport var WithConnectionFactory = client.WithConnectionFactory var WithAddressProvider = client.WithAddressProvider var WithStaticAddress = client.WithStaticAddress +var WithDNSSRV = client.WithDNSSRV var NewGrpcClient = client.NewGrpcClient @@ -39,3 +41,5 @@ var NewGrpcServerWithPort = server.NewGrpcServerWithPort var PeerCertificate = server.PeerCertificate var PeerIdentity = server.PeerIdentity var NewHealthCheckServicer = server.NewHealthCheckServicer +var SessionFromContext = server.SessionFromContext +var WithTransport = server.WithTransport diff --git a/grpc/internal/client/BUILD.bazel b/grpc/internal/client/BUILD.bazel index 42d11ce..7e89d3a 100644 --- a/grpc/internal/client/BUILD.bazel +++ b/grpc/internal/client/BUILD.bazel @@ -9,12 +9,18 @@ package( go_library( name = "client", - srcs = ["client.go"], + srcs = [ + "client.go", + "dns_srv.go", + ], importpath = "go.fuhry.dev/runtime/grpc/internal/client", deps = [ + "//constants", "//grpc/internal/common", "//mtls", + "//net/dns", "//sd", + "//utils/context", "@org_golang_google_grpc//:grpc", ], ) diff --git a/grpc/internal/client/client.go b/grpc/internal/client/client.go index bc54369..60d0edf 100644 --- a/grpc/internal/client/client.go +++ b/grpc/internal/client/client.go @@ -95,6 +95,19 @@ func WithStaticAddress(addresses ...*net.TCPAddr) ClientOption { } } +func WithDNSSRV() ClientOption { + return &clientOption{ + f: func(c *client) error { + ap := &dnsSrvAddressProvider{ + serverId: c.serverId, + } + + c.watcher = ap + return nil + }, + } +} + func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity, opts ...ClientOption) (Client, error) { cl := &client{ ctx: ctx, diff --git a/grpc/internal/client/dns_srv.go b/grpc/internal/client/dns_srv.go new file mode 100644 index 0000000..00241e9 --- /dev/null +++ b/grpc/internal/client/dns_srv.go @@ -0,0 +1,36 @@ +package client + +import ( + "go.fuhry.dev/runtime/constants" + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/net/dns" + "go.fuhry.dev/runtime/sd" + "go.fuhry.dev/runtime/utils/context" +) + +type dnsSrvAddressProvider struct { + serverId mtls.Identity +} + +var _ AddressProvider = &dnsSrvAddressProvider{} + +func (d *dnsSrvAddressProvider) GetAddrs(context.Context) ([]sd.ServiceAddress, error) { + results, err := dns.ResolveSRV(d.serverId.Name(), "grpc", constants.SDDomain) + if err != nil { + return nil, err + } + + var sdAddrs []sd.ServiceAddress + for _, r := range results { + sdAddrs = append(sdAddrs, sd.ServiceAddress{ + Hostname: r.Hostname, + IP4: r.Ip4.String(), + IP6: r.Ip6.String(), + Port: r.Port, + Protocol: sd.ProtocolGRPC, + Service: d.serverId.Name(), + }) + } + + return sdAddrs, nil +} diff --git a/grpc/internal/server/server.go b/grpc/internal/server/server.go index d5a3739..93881d3 100644 --- a/grpc/internal/server/server.go +++ b/grpc/internal/server/server.go @@ -3,6 +3,8 @@ package server import ( "context" "crypto/tls" + "crypto/x509" + "errors" "flag" "fmt" "math/rand" @@ -38,21 +40,42 @@ type Server struct { hc HealthCheckServicer } +type ServerOption interface { + apply(*Server) error +} + +type serverOption struct { + callback func(*Server) error +} + +func (o *serverOption) apply(s *Server) error { + return o.callback(s) +} + +func WithTransport(cf common.ConnectionFactory) ServerOption { + return &serverOption{ + callback: func(s *Server) error { + s.connFac = cf + return nil + }, + } +} + var defaultPort *uint func RandomPort() uint { return uint(1025 + (uint(rand.Int()) % (65535 - 1025))) } -func NewGrpcServer(id mtls.Identity) (*Server, error) { +func NewGrpcServer(id mtls.Identity, opts ...ServerOption) (*Server, error) { if !flag.Parsed() { panic("cannot start grpc services before flags are parsed") } - return NewGrpcServerWithPort(id, uint16(*defaultPort)) + return NewGrpcServerWithPort(id, uint16(*defaultPort), opts...) } -func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) { +func NewGrpcServerWithPort(id mtls.Identity, port uint16, opts ...ServerOption) (*Server, error) { etcdc, err := sd.NewDefaultEtcdClient() if err != nil { panic(err) @@ -87,10 +110,19 @@ func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) { verifier: cv, log: log.WithPrefix(fmt.Sprintf("grpcServer:%s", id.Name())), sessions: sessionsLru, - connFac: common.NewDefaultConnectionFactory(), hc: NewHealthCheckServicer(), } + for _, opt := range opts { + if err := opt.apply(server); err != nil { + return nil, err + } + } + + if server.connFac == nil { + server.connFac = common.NewDefaultConnectionFactory() + } + return server, nil } diff --git a/net/dns/dns_cache.go b/net/dns/dns_cache.go index 6f63f2a..bb024af 100644 --- a/net/dns/dns_cache.go +++ b/net/dns/dns_cache.go @@ -16,6 +16,15 @@ import ( "go.fuhry.dev/runtime/utils/log" ) +type SRVResult struct { + Priority uint16 + Weight uint16 + Port uint16 + Hostname string + Ip4 net.IP + Ip6 net.IP +} + var dnsCache *lru.Cache var dnsCacheInit sync.Once @@ -105,9 +114,120 @@ func ResolveDualStack(name string) (string, string, error) { return ip4, ip6, nil } -func doDualStackQuery(hostname string) (*dns.Msg, error) { - var msg *dns.Msg +func ResolveSRV(service, protocol, domain string) ([]SRVResult, error) { + cc, err := dnsClientConfig() + if err != nil { + return nil, err + } + + name := fmt.Sprintf("_%s._%s.%s.", service, protocol, domain) + + ctx, cancel := context.WithCancel(context.Background()) + result := make(chan *dns.Msg, 0) + var wg sync.WaitGroup + + query := makeQuery(name, dns.TypeSRV) + client := &dns.Client{} + + for _, server := range cc.Servers { + wg.Add(1) + go (func() { + defer wg.Done() + + log.WithPrefix("ResolveSRV").V(3).Debugf("attempting query against server %q: %v", server, query) + answer, _, err := client.Exchange(query, net.JoinHostPort(server, "53")) + if err == nil && answer != nil { + log.WithPrefix("ResolveSRV").V(3).Debugf("got answer %+v", answer) + select { + case result <- answer: + case <-ctx.Done(): + } + } else { + log.WithPrefix("ResolveSRV").V(3).Debugf("error querying server %q: %v", server, err) + } + })() + } + + go (func() { + wg.Wait() + close(result) + cancel() + })() + + var answer *dns.Msg + select { + case answer = <-result: + case <-ctx.Done(): + return nil, ctx.Err() + } + + if answer == nil { + return nil, fmt.Errorf("no answer received for query SRV %q", name) + } + + var mu sync.Mutex + var results []SRVResult + var wg2 = sync.WaitGroup{} + + for _, rr := range answer.Answer { + if srvRecord, ok := rr.(*dns.SRV); ok { + wg2.Add(1) + go (func() { + defer wg2.Done() + + if ip4Str, ip6Str, err := ResolveDualStack(srvRecord.Target); err == nil { + ip4 := net.ParseIP(ip4Str) + ip6 := net.ParseIP(ip6Str) + answer := SRVResult{ + Priority: srvRecord.Priority, + Weight: srvRecord.Weight, + Port: srvRecord.Port, + Hostname: srvRecord.Target, + Ip4: ip4, + Ip6: ip6, + } + + mu.Lock() + defer mu.Unlock() + results = append(results, answer) + } + })() + } + } + + wg2.Wait() + + if len(results) < 1 { + return nil, fmt.Errorf("no results returned for SRV %q", name) + } + + return results, nil +} + +func makeQuery(name string, qtype uint16) *dns.Msg { + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + Opcode: dns.OpcodeQuery, + Response: false, + RecursionDesired: true, + AuthenticatedData: true, + }, + Question: []dns.Question{ + dns.Question{ + Name: name, + Qtype: qtype, + Qclass: dns.ClassINET, + }, + }, + Extra: []dns.RR{ + newEDNSCookie(), + }, + } +} + +func dnsClientConfig() (*dns.ClientConfig, error) { // On Linux systems, the stub resolv.conf points at systemd-resolved which uses the hosts // file. We don't want this - we need the system's addresses as described by the network's // DNS server. @@ -115,7 +235,13 @@ func doDualStackQuery(hostname string) (*dns.Msg, error) { if err := fsutil.FileExistsAndIsReadable(resolvConfPath); err != nil { resolvConfPath = "/etc/resolv.conf" } - cc, err := dns.ClientConfigFromFile(resolvConfPath) + return dns.ClientConfigFromFile(resolvConfPath) +} + +func doDualStackQuery(hostname string) (*dns.Msg, error) { + var msg *dns.Msg + + cc, err := dnsClientConfig() if err != nil { return nil, err } @@ -123,25 +249,7 @@ func doDualStackQuery(hostname string) (*dns.Msg, error) { log.V(3).Debugf("cache miss, attempting dualstack dns query for hostname: %q", hostname) log.V(3).Debugf("will resolve using DNS servers: %+v", cc.Servers) for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} { - query := &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: dns.Id(), - Opcode: dns.OpcodeQuery, - Response: false, - RecursionDesired: true, - AuthenticatedData: true, - }, - Question: make([]dns.Question, 1), - Extra: []dns.RR{ - newEDNSCookie(), - }, - } - - query.Question[0] = dns.Question{ - Name: hostname, - Qtype: qtype, - Qclass: dns.ClassINET, - } + query := makeQuery(hostname, qtype) mu := &sync.Mutex{} done := false