From 0f3768d83cdb773a00838b7f7304e61355ad8604 Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Tue, 27 Feb 2024 16:03:31 -0500 Subject: [PATCH] sase/happy_eyeballs: use HE for client dial; add flags to block v4 or v6 connections --- sase/client.go | 1 + sase/happy_eyeballs.go | 77 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/sase/client.go b/sase/client.go index 0ed4f39..6191bbd 100644 --- a/sase/client.go +++ b/sase/client.go @@ -28,6 +28,7 @@ func NewClient(id mtls.Identity) (Client, error) { } tlsConfig.RootCAs = nil wsClient := &websocket.Dialer{ + NetDialContext: dialHappyEyeballsNameContext, TLSClientConfig: tlsConfig, } diff --git a/sase/happy_eyeballs.go b/sase/happy_eyeballs.go index 9eb789a..5ebd07e 100644 --- a/sase/happy_eyeballs.go +++ b/sase/happy_eyeballs.go @@ -2,13 +2,69 @@ package sase import ( "context" + "flag" "fmt" "net" + "strconv" "strings" "sync" + + "go.fuhry.dev/runtime/net/dns" + "go.fuhry.dev/runtime/utils/log" +) + +var logger = log.WithPrefix("sase/happy_eyeballs") +var ( + ipv4Enable, + ipv6Enable bool ) +func init() { + flag.BoolVar(&ipv4Enable, "happy-eyeballs.ipv4", true, "allow happy eyeballs to dial IPv4 addresses") + flag.BoolVar(&ipv6Enable, "happy-eyeballs.ipv6", true, "allow happy eyeballs to dial IPv6 addresses") +} + +func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) { + logger.V(1).Debugf("dialHappyEyeballsNameContext: %s/%s", network, addr) + + if network != "tcp" { + return nil, fmt.Errorf("unsupported network: only tcp is supported, not %q", network) + } + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + + var addrs []net.Addr + + ipv4, ipv6, err := dns.ResolveDualStack(host) + if err != nil { + return nil, err + } + + if ipv4 != "" { + logger.V(1).Debugf("dialHappyEyeballsNameContext: ipv4: %s", ipv4) + addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv4)}) + } + if ipv6 != "" { + logger.V(1).Debugf("dialHappyEyeballsNameContext: ipv6: %s", ipv6) + addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv6)}) + } + + return dialHappyEyeballs(ctx, addrs, uint16(port)) +} + func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) { + if !ipv4Enable && !ipv6Enable { + return nil, fmt.Errorf("cannot dial happy eyeballs connection: at least one address family must be enabled") + } + + logger.V(2).Debugf("ipv4: %t / ipv6: %t", ipv4Enable, ipv6Enable) + dialer := &net.Dialer{} lock := &sync.Mutex{} defer lock.Unlock() @@ -23,10 +79,20 @@ func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net. dialContext, dialCancel := context.WithCancel(ctx) dialFunc := func(addr net.Addr) { - addr = &net.TCPAddr{ - IP: net.ParseIP(addr.String()), - Port: int(port), + if ipaddr, ok := addr.(*net.IPAddr); ok { + if ipaddr.IP.To4() != nil && !ipv4Enable { + logger.V(2).Debugf("skip address %s: ipv4 not enabled", ipaddr.IP) + return + } else if ipaddr.IP.To4() == nil && ipaddr.IP.To16() != nil && !ipv6Enable { + logger.V(2).Debugf("skip address %s: ipv6 not enabled", ipaddr.IP) + return + } + addr = &net.TCPAddr{ + IP: ipaddr.IP, + Port: int(port), + } } + logger.V(1).Debugf("dialHappyEyeballsContext: dialing: %s", addr) conn, err := dialer.DialContext(dialContext, addr.Network(), addr.String()) lock.Lock() defer lock.Unlock() @@ -36,8 +102,11 @@ func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net. } return } + logger.V(1).Debugf("dialHappyEyeballsContext: got connection via: %s", addr) if !done { connChan <- conn + } else { + conn.Close() } } @@ -58,7 +127,7 @@ func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net. } case <-ctx.Done(): dialCancel() - return nil, context.Canceled + return nil, ctx.Err() } } } -- 2.50.1