From d1eabcd312b62a21a56f10fc0b0dadfd90ff4885 Mon Sep 17 00:00:00 2001 From: Jason Lyu Date: Wed, 16 Apr 2025 04:35:59 +0800 Subject: [PATCH] Refactor(dialer): use `DefaultDialer` (#465) --- dialer/dialer.go | 49 ++++++++++++++++++++++++++++++++---------------- dns/resolver.go | 2 +- engine/engine.go | 6 +++--- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/dialer/dialer.go b/dialer/dialer.go index c4ac76a..e8f642d 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -8,11 +8,18 @@ import ( "go.uber.org/atomic" ) -var ( - DefaultInterfaceName = atomic.NewString("") - DefaultInterfaceIndex = atomic.NewInt32(0) - DefaultRoutingMark = atomic.NewInt32(0) -) +// DefaultDialer is the default Dialer and is used by DialContext and ListenPacket. +var DefaultDialer = &Dialer{ + InterfaceName: atomic.NewString(""), + InterfaceIndex: atomic.NewInt32(0), + RoutingMark: atomic.NewInt32(0), +} + +type Dialer struct { + InterfaceName *atomic.String + InterfaceIndex *atomic.Int32 + RoutingMark *atomic.Int32 +} type Options struct { // InterfaceName is the name of interface/device to bind. @@ -31,15 +38,25 @@ type Options struct { RoutingMark int } +// DialContext is a wrapper around DefaultDialer.DialContext. func DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return DialContextWithOptions(ctx, network, address, &Options{ - InterfaceName: DefaultInterfaceName.Load(), - InterfaceIndex: int(DefaultInterfaceIndex.Load()), - RoutingMark: int(DefaultRoutingMark.Load()), + return DefaultDialer.DialContext(ctx, network, address) +} + +// ListenPacket is a wrapper around DefaultDialer.ListenPacket. +func ListenPacket(network, address string) (net.PacketConn, error) { + return DefaultDialer.ListenPacket(network, address) +} + +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContextWithOptions(ctx, network, address, &Options{ + InterfaceName: d.InterfaceName.Load(), + InterfaceIndex: int(d.InterfaceIndex.Load()), + RoutingMark: int(d.RoutingMark.Load()), }) } -func DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) { +func (_ *Dialer) DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) { d := &net.Dialer{ Control: func(network, address string, c syscall.RawConn) error { return setSocketOptions(network, address, c, opts) @@ -48,15 +65,15 @@ func DialContextWithOptions(ctx context.Context, network, address string, opts * return d.DialContext(ctx, network, address) } -func ListenPacket(network, address string) (net.PacketConn, error) { - return ListenPacketWithOptions(network, address, &Options{ - InterfaceName: DefaultInterfaceName.Load(), - InterfaceIndex: int(DefaultInterfaceIndex.Load()), - RoutingMark: int(DefaultRoutingMark.Load()), +func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) { + return d.ListenPacketWithOptions(network, address, &Options{ + InterfaceName: d.InterfaceName.Load(), + InterfaceIndex: int(d.InterfaceIndex.Load()), + RoutingMark: int(d.RoutingMark.Load()), }) } -func ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) { +func (_ *Dialer) ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) { lc := &net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { return setSocketOptions(network, address, c, opts) diff --git a/dns/resolver.go b/dns/resolver.go index 8748aa4..858223b 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -10,5 +10,5 @@ func init() { // We must use this DialContext to query DNS // when using net default resolver. net.DefaultResolver.PreferGo = true - net.DefaultResolver.Dial = dialer.DialContext + net.DefaultResolver.Dial = dialer.DefaultDialer.DialContext } diff --git a/engine/engine.go b/engine/engine.go index 9e8cfe6..3843ead 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -117,13 +117,13 @@ func general(k *Key) error { if err != nil { return err } - dialer.DefaultInterfaceName.Store(iface.Name) - dialer.DefaultInterfaceIndex.Store(int32(iface.Index)) + dialer.DefaultDialer.InterfaceName.Store(iface.Name) + dialer.DefaultDialer.InterfaceIndex.Store(int32(iface.Index)) log.Infof("[DIALER] bind to interface: %s", k.Interface) } if k.Mark != 0 { - dialer.DefaultRoutingMark.Store(int32(k.Mark)) + dialer.DefaultDialer.RoutingMark.Store(int32(k.Mark)) log.Infof("[DIALER] set fwmark: %#x", k.Mark) }