// type aliases: server
type Server = server.Server
+type ServerOption = server.ServerOption
// function aliases: common
var NewDefaultConnectionFactory = common.NewDefaultConnectionFactory
var WithConnectionFactory = client.WithConnectionFactory
var WithAddressProvider = client.WithAddressProvider
var WithStaticAddress = client.WithStaticAddress
+var WithDNSSRV = client.WithDNSSRV
var NewGrpcClient = client.NewGrpcClient
var PeerCertificate = server.PeerCertificate
var PeerIdentity = server.PeerIdentity
var NewHealthCheckServicer = server.NewHealthCheckServicer
+var SessionFromContext = server.SessionFromContext
+var WithTransport = server.WithTransport
--- /dev/null
+package client
+
+import (
+ "go.fuhry.dev/runtime/constants"
+ "go.fuhry.dev/runtime/mtls"
+ "go.fuhry.dev/runtime/net/dns"
+ "go.fuhry.dev/runtime/sd"
+ "go.fuhry.dev/runtime/utils/context"
+)
+
+type dnsSrvAddressProvider struct {
+ serverId mtls.Identity
+}
+
+var _ AddressProvider = &dnsSrvAddressProvider{}
+
+func (d *dnsSrvAddressProvider) GetAddrs(context.Context) ([]sd.ServiceAddress, error) {
+ results, err := dns.ResolveSRV(d.serverId.Name(), "grpc", constants.SDDomain)
+ if err != nil {
+ return nil, err
+ }
+
+ var sdAddrs []sd.ServiceAddress
+ for _, r := range results {
+ sdAddrs = append(sdAddrs, sd.ServiceAddress{
+ Hostname: r.Hostname,
+ IP4: r.Ip4.String(),
+ IP6: r.Ip6.String(),
+ Port: r.Port,
+ Protocol: sd.ProtocolGRPC,
+ Service: d.serverId.Name(),
+ })
+ }
+
+ return sdAddrs, nil
+}
import (
"context"
"crypto/tls"
+ "crypto/x509"
+ "errors"
"flag"
"fmt"
"math/rand"
hc HealthCheckServicer
}
+type ServerOption interface {
+ apply(*Server) error
+}
+
+type serverOption struct {
+ callback func(*Server) error
+}
+
+func (o *serverOption) apply(s *Server) error {
+ return o.callback(s)
+}
+
+func WithTransport(cf common.ConnectionFactory) ServerOption {
+ return &serverOption{
+ callback: func(s *Server) error {
+ s.connFac = cf
+ return nil
+ },
+ }
+}
+
var defaultPort *uint
func RandomPort() uint {
return uint(1025 + (uint(rand.Int()) % (65535 - 1025)))
}
-func NewGrpcServer(id mtls.Identity) (*Server, error) {
+func NewGrpcServer(id mtls.Identity, opts ...ServerOption) (*Server, error) {
if !flag.Parsed() {
panic("cannot start grpc services before flags are parsed")
}
- return NewGrpcServerWithPort(id, uint16(*defaultPort))
+ return NewGrpcServerWithPort(id, uint16(*defaultPort), opts...)
}
-func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) {
+func NewGrpcServerWithPort(id mtls.Identity, port uint16, opts ...ServerOption) (*Server, error) {
etcdc, err := sd.NewDefaultEtcdClient()
if err != nil {
panic(err)
verifier: cv,
log: log.WithPrefix(fmt.Sprintf("grpcServer:%s", id.Name())),
sessions: sessionsLru,
- connFac: common.NewDefaultConnectionFactory(),
hc: NewHealthCheckServicer(),
}
+ for _, opt := range opts {
+ if err := opt.apply(server); err != nil {
+ return nil, err
+ }
+ }
+
+ if server.connFac == nil {
+ server.connFac = common.NewDefaultConnectionFactory()
+ }
+
return server, nil
}
"go.fuhry.dev/runtime/utils/log"
)
+type SRVResult struct {
+ Priority uint16
+ Weight uint16
+ Port uint16
+ Hostname string
+ Ip4 net.IP
+ Ip6 net.IP
+}
+
var dnsCache *lru.Cache
var dnsCacheInit sync.Once
return ip4, ip6, nil
}
-func doDualStackQuery(hostname string) (*dns.Msg, error) {
- var msg *dns.Msg
+func ResolveSRV(service, protocol, domain string) ([]SRVResult, error) {
+ cc, err := dnsClientConfig()
+ if err != nil {
+ return nil, err
+ }
+
+ name := fmt.Sprintf("_%s._%s.%s.", service, protocol, domain)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ result := make(chan *dns.Msg, 0)
+ var wg sync.WaitGroup
+
+ query := makeQuery(name, dns.TypeSRV)
+ client := &dns.Client{}
+
+ for _, server := range cc.Servers {
+ wg.Add(1)
+ go (func() {
+ defer wg.Done()
+
+ log.WithPrefix("ResolveSRV").V(3).Debugf("attempting query against server %q: %v", server, query)
+ answer, _, err := client.Exchange(query, net.JoinHostPort(server, "53"))
+ if err == nil && answer != nil {
+ log.WithPrefix("ResolveSRV").V(3).Debugf("got answer %+v", answer)
+ select {
+ case result <- answer:
+ case <-ctx.Done():
+ }
+ } else {
+ log.WithPrefix("ResolveSRV").V(3).Debugf("error querying server %q: %v", server, err)
+ }
+ })()
+ }
+
+ go (func() {
+ wg.Wait()
+ close(result)
+ cancel()
+ })()
+
+ var answer *dns.Msg
+ select {
+ case answer = <-result:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+
+ if answer == nil {
+ return nil, fmt.Errorf("no answer received for query SRV %q", name)
+ }
+
+ var mu sync.Mutex
+ var results []SRVResult
+ var wg2 = sync.WaitGroup{}
+
+ for _, rr := range answer.Answer {
+ if srvRecord, ok := rr.(*dns.SRV); ok {
+ wg2.Add(1)
+ go (func() {
+ defer wg2.Done()
+
+ if ip4Str, ip6Str, err := ResolveDualStack(srvRecord.Target); err == nil {
+ ip4 := net.ParseIP(ip4Str)
+ ip6 := net.ParseIP(ip6Str)
+ answer := SRVResult{
+ Priority: srvRecord.Priority,
+ Weight: srvRecord.Weight,
+ Port: srvRecord.Port,
+ Hostname: srvRecord.Target,
+ Ip4: ip4,
+ Ip6: ip6,
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ results = append(results, answer)
+ }
+ })()
+ }
+ }
+
+ wg2.Wait()
+
+ if len(results) < 1 {
+ return nil, fmt.Errorf("no results returned for SRV %q", name)
+ }
+
+ return results, nil
+}
+
+func makeQuery(name string, qtype uint16) *dns.Msg {
+ return &dns.Msg{
+ MsgHdr: dns.MsgHdr{
+ Id: dns.Id(),
+ Opcode: dns.OpcodeQuery,
+ Response: false,
+ RecursionDesired: true,
+ AuthenticatedData: true,
+ },
+ Question: []dns.Question{
+ dns.Question{
+ Name: name,
+ Qtype: qtype,
+ Qclass: dns.ClassINET,
+ },
+ },
+ Extra: []dns.RR{
+ newEDNSCookie(),
+ },
+ }
+}
+
+func dnsClientConfig() (*dns.ClientConfig, error) {
// On Linux systems, the stub resolv.conf points at systemd-resolved which uses the hosts
// file. We don't want this - we need the system's addresses as described by the network's
// DNS server.
if err := fsutil.FileExistsAndIsReadable(resolvConfPath); err != nil {
resolvConfPath = "/etc/resolv.conf"
}
- cc, err := dns.ClientConfigFromFile(resolvConfPath)
+ return dns.ClientConfigFromFile(resolvConfPath)
+}
+
+func doDualStackQuery(hostname string) (*dns.Msg, error) {
+ var msg *dns.Msg
+
+ cc, err := dnsClientConfig()
if err != nil {
return nil, err
}
log.V(3).Debugf("cache miss, attempting dualstack dns query for hostname: %q", hostname)
log.V(3).Debugf("will resolve using DNS servers: %+v", cc.Servers)
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
- query := &dns.Msg{
- MsgHdr: dns.MsgHdr{
- Id: dns.Id(),
- Opcode: dns.OpcodeQuery,
- Response: false,
- RecursionDesired: true,
- AuthenticatedData: true,
- },
- Question: make([]dns.Question, 1),
- Extra: []dns.RR{
- newEDNSCookie(),
- },
- }
-
- query.Question[0] = dns.Question{
- Name: hostname,
- Qtype: qtype,
- Qclass: dns.ClassINET,
- }
+ query := makeQuery(hostname, qtype)
mu := &sync.Mutex{}
done := false