]> go.fuhry.dev Git - runtime.git/commitdiff
sase/happy_eyeballs: use HE for client dial; add flags to block v4 or v6 connections
authorDan Fuhry <dan@fuhry.com>
Tue, 27 Feb 2024 21:03:31 +0000 (16:03 -0500)
committerDan Fuhry <dan@fuhry.com>
Tue, 27 Feb 2024 21:03:31 +0000 (16:03 -0500)
sase/client.go
sase/happy_eyeballs.go

index 0ed4f39a7fda39356bd8d2e451d24f1171fe6e71..6191bbdb5837f633754c43b58424301dbfc6a81e 100644 (file)
@@ -28,6 +28,7 @@ func NewClient(id mtls.Identity) (Client, error) {
        }
        tlsConfig.RootCAs = nil
        wsClient := &websocket.Dialer{
+               NetDialContext:  dialHappyEyeballsNameContext,
                TLSClientConfig: tlsConfig,
        }
 
index 9eb789afc38fb11609d1d0ac7a5a779d48b7e323..5ebd07e8c919b38697027a3e3b625622df6be355 100644 (file)
@@ -2,13 +2,69 @@ package sase
 
 import (
        "context"
+       "flag"
        "fmt"
        "net"
+       "strconv"
        "strings"
        "sync"
+
+       "go.fuhry.dev/runtime/net/dns"
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+var logger = log.WithPrefix("sase/happy_eyeballs")
+var (
+       ipv4Enable,
+       ipv6Enable bool
 )
 
+func init() {
+       flag.BoolVar(&ipv4Enable, "happy-eyeballs.ipv4", true, "allow happy eyeballs to dial IPv4 addresses")
+       flag.BoolVar(&ipv6Enable, "happy-eyeballs.ipv6", true, "allow happy eyeballs to dial IPv6 addresses")
+}
+
+func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) {
+       logger.V(1).Debugf("dialHappyEyeballsNameContext: %s/%s", network, addr)
+
+       if network != "tcp" {
+               return nil, fmt.Errorf("unsupported network: only tcp is supported, not %q", network)
+       }
+       host, portStr, err := net.SplitHostPort(addr)
+       if err != nil {
+               return nil, err
+       }
+       port, err := strconv.Atoi(portStr)
+       if err != nil {
+               return nil, err
+       }
+
+       var addrs []net.Addr
+
+       ipv4, ipv6, err := dns.ResolveDualStack(host)
+       if err != nil {
+               return nil, err
+       }
+
+       if ipv4 != "" {
+               logger.V(1).Debugf("dialHappyEyeballsNameContext: ipv4: %s", ipv4)
+               addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv4)})
+       }
+       if ipv6 != "" {
+               logger.V(1).Debugf("dialHappyEyeballsNameContext: ipv6: %s", ipv6)
+               addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv6)})
+       }
+
+       return dialHappyEyeballs(ctx, addrs, uint16(port))
+}
+
 func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) {
+       if !ipv4Enable && !ipv6Enable {
+               return nil, fmt.Errorf("cannot dial happy eyeballs connection: at least one address family must be enabled")
+       }
+
+       logger.V(2).Debugf("ipv4: %t / ipv6: %t", ipv4Enable, ipv6Enable)
+
        dialer := &net.Dialer{}
        lock := &sync.Mutex{}
        defer lock.Unlock()
@@ -23,10 +79,20 @@ func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.
        dialContext, dialCancel := context.WithCancel(ctx)
 
        dialFunc := func(addr net.Addr) {
-               addr = &net.TCPAddr{
-                       IP:   net.ParseIP(addr.String()),
-                       Port: int(port),
+               if ipaddr, ok := addr.(*net.IPAddr); ok {
+                       if ipaddr.IP.To4() != nil && !ipv4Enable {
+                               logger.V(2).Debugf("skip address %s: ipv4 not enabled", ipaddr.IP)
+                               return
+                       } else if ipaddr.IP.To4() == nil && ipaddr.IP.To16() != nil && !ipv6Enable {
+                               logger.V(2).Debugf("skip address %s: ipv6 not enabled", ipaddr.IP)
+                               return
+                       }
+                       addr = &net.TCPAddr{
+                               IP:   ipaddr.IP,
+                               Port: int(port),
+                       }
                }
+               logger.V(1).Debugf("dialHappyEyeballsContext: dialing: %s", addr)
                conn, err := dialer.DialContext(dialContext, addr.Network(), addr.String())
                lock.Lock()
                defer lock.Unlock()
@@ -36,8 +102,11 @@ func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.
                        }
                        return
                }
+               logger.V(1).Debugf("dialHappyEyeballsContext: got connection via: %s", addr)
                if !done {
                        connChan <- conn
+               } else {
+                       conn.Close()
                }
        }
 
@@ -58,7 +127,7 @@ func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.
                        }
                case <-ctx.Done():
                        dialCancel()
-                       return nil, context.Canceled
+                       return nil, ctx.Err()
                }
        }
 }