Refactor(dialer): use DefaultDialer (#465)

This commit is contained in:
Jason Lyu
2025-04-16 04:35:59 +08:00
committed by GitHub
parent f135a13b33
commit d1eabcd312
3 changed files with 37 additions and 20 deletions

View File

@@ -8,11 +8,18 @@ import (
"go.uber.org/atomic" "go.uber.org/atomic"
) )
var ( // DefaultDialer is the default Dialer and is used by DialContext and ListenPacket.
DefaultInterfaceName = atomic.NewString("") var DefaultDialer = &Dialer{
DefaultInterfaceIndex = atomic.NewInt32(0) InterfaceName: atomic.NewString(""),
DefaultRoutingMark = atomic.NewInt32(0) InterfaceIndex: atomic.NewInt32(0),
) RoutingMark: atomic.NewInt32(0),
}
type Dialer struct {
InterfaceName *atomic.String
InterfaceIndex *atomic.Int32
RoutingMark *atomic.Int32
}
type Options struct { type Options struct {
// InterfaceName is the name of interface/device to bind. // InterfaceName is the name of interface/device to bind.
@@ -31,15 +38,25 @@ type Options struct {
RoutingMark int RoutingMark int
} }
// DialContext is a wrapper around DefaultDialer.DialContext.
func DialContext(ctx context.Context, network, address string) (net.Conn, error) { func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DialContextWithOptions(ctx, network, address, &Options{ return DefaultDialer.DialContext(ctx, network, address)
InterfaceName: DefaultInterfaceName.Load(), }
InterfaceIndex: int(DefaultInterfaceIndex.Load()),
RoutingMark: int(DefaultRoutingMark.Load()), // ListenPacket is a wrapper around DefaultDialer.ListenPacket.
func ListenPacket(network, address string) (net.PacketConn, error) {
return DefaultDialer.ListenPacket(network, address)
}
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContextWithOptions(ctx, network, address, &Options{
InterfaceName: d.InterfaceName.Load(),
InterfaceIndex: int(d.InterfaceIndex.Load()),
RoutingMark: int(d.RoutingMark.Load()),
}) })
} }
func DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) { func (_ *Dialer) DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) {
d := &net.Dialer{ d := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts) return setSocketOptions(network, address, c, opts)
@@ -48,15 +65,15 @@ func DialContextWithOptions(ctx context.Context, network, address string, opts *
return d.DialContext(ctx, network, address) return d.DialContext(ctx, network, address)
} }
func ListenPacket(network, address string) (net.PacketConn, error) { func (d *Dialer) ListenPacket(network, address string) (net.PacketConn, error) {
return ListenPacketWithOptions(network, address, &Options{ return d.ListenPacketWithOptions(network, address, &Options{
InterfaceName: DefaultInterfaceName.Load(), InterfaceName: d.InterfaceName.Load(),
InterfaceIndex: int(DefaultInterfaceIndex.Load()), InterfaceIndex: int(d.InterfaceIndex.Load()),
RoutingMark: int(DefaultRoutingMark.Load()), RoutingMark: int(d.RoutingMark.Load()),
}) })
} }
func ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) { func (_ *Dialer) ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) {
lc := &net.ListenConfig{ lc := &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts) return setSocketOptions(network, address, c, opts)

View File

@@ -10,5 +10,5 @@ func init() {
// We must use this DialContext to query DNS // We must use this DialContext to query DNS
// when using net default resolver. // when using net default resolver.
net.DefaultResolver.PreferGo = true net.DefaultResolver.PreferGo = true
net.DefaultResolver.Dial = dialer.DialContext net.DefaultResolver.Dial = dialer.DefaultDialer.DialContext
} }

View File

@@ -117,13 +117,13 @@ func general(k *Key) error {
if err != nil { if err != nil {
return err return err
} }
dialer.DefaultInterfaceName.Store(iface.Name) dialer.DefaultDialer.InterfaceName.Store(iface.Name)
dialer.DefaultInterfaceIndex.Store(int32(iface.Index)) dialer.DefaultDialer.InterfaceIndex.Store(int32(iface.Index))
log.Infof("[DIALER] bind to interface: %s", k.Interface) log.Infof("[DIALER] bind to interface: %s", k.Interface)
} }
if k.Mark != 0 { if k.Mark != 0 {
dialer.DefaultRoutingMark.Store(int32(k.Mark)) dialer.DefaultDialer.RoutingMark.Store(int32(k.Mark))
log.Infof("[DIALER] set fwmark: %#x", k.Mark) log.Infof("[DIALER] set fwmark: %#x", k.Mark)
} }