import (
"context"
+ "flag"
"fmt"
"net"
+ "strconv"
"strings"
"sync"
+
+ "go.fuhry.dev/runtime/net/dns"
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+var logger = log.WithPrefix("sase/happy_eyeballs")
+var (
+ ipv4Enable,
+ ipv6Enable bool
)
+func init() {
+ flag.BoolVar(&ipv4Enable, "happy-eyeballs.ipv4", true, "allow happy eyeballs to dial IPv4 addresses")
+ flag.BoolVar(&ipv6Enable, "happy-eyeballs.ipv6", true, "allow happy eyeballs to dial IPv6 addresses")
+}
+
+func dialHappyEyeballsNameContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ logger.V(1).Debugf("dialHappyEyeballsNameContext: %s/%s", network, addr)
+
+ if network != "tcp" {
+ return nil, fmt.Errorf("unsupported network: only tcp is supported, not %q", network)
+ }
+ host, portStr, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return nil, err
+ }
+
+ var addrs []net.Addr
+
+ ipv4, ipv6, err := dns.ResolveDualStack(host)
+ if err != nil {
+ return nil, err
+ }
+
+ if ipv4 != "" {
+ logger.V(1).Debugf("dialHappyEyeballsNameContext: ipv4: %s", ipv4)
+ addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv4)})
+ }
+ if ipv6 != "" {
+ logger.V(1).Debugf("dialHappyEyeballsNameContext: ipv6: %s", ipv6)
+ addrs = append(addrs, &net.IPAddr{IP: net.ParseIP(ipv6)})
+ }
+
+ return dialHappyEyeballs(ctx, addrs, uint16(port))
+}
+
func dialHappyEyeballs(ctx context.Context, addrs []net.Addr, port uint16) (net.Conn, error) {
+ if !ipv4Enable && !ipv6Enable {
+ return nil, fmt.Errorf("cannot dial happy eyeballs connection: at least one address family must be enabled")
+ }
+
+ logger.V(2).Debugf("ipv4: %t / ipv6: %t", ipv4Enable, ipv6Enable)
+
dialer := &net.Dialer{}
lock := &sync.Mutex{}
defer lock.Unlock()
dialContext, dialCancel := context.WithCancel(ctx)
dialFunc := func(addr net.Addr) {
- addr = &net.TCPAddr{
- IP: net.ParseIP(addr.String()),
- Port: int(port),
+ if ipaddr, ok := addr.(*net.IPAddr); ok {
+ if ipaddr.IP.To4() != nil && !ipv4Enable {
+ logger.V(2).Debugf("skip address %s: ipv4 not enabled", ipaddr.IP)
+ return
+ } else if ipaddr.IP.To4() == nil && ipaddr.IP.To16() != nil && !ipv6Enable {
+ logger.V(2).Debugf("skip address %s: ipv6 not enabled", ipaddr.IP)
+ return
+ }
+ addr = &net.TCPAddr{
+ IP: ipaddr.IP,
+ Port: int(port),
+ }
}
+ logger.V(1).Debugf("dialHappyEyeballsContext: dialing: %s", addr)
conn, err := dialer.DialContext(dialContext, addr.Network(), addr.String())
lock.Lock()
defer lock.Unlock()
}
return
}
+ logger.V(1).Debugf("dialHappyEyeballsContext: got connection via: %s", addr)
if !done {
connChan <- conn
+ } else {
+ conn.Close()
}
}
}
case <-ctx.Done():
dialCancel()
- return nil, context.Canceled
+ return nil, ctx.Err()
}
}
}