From a4bedf6080319be0dde7f77b109e247a3af6ffe3 Mon Sep 17 00:00:00 2001 From: xjasonlyu Date: Tue, 29 Mar 2022 00:02:19 +0800 Subject: [PATCH] Refactor: new dialer impl --- common/singledo/singledo.go | 66 ++++++++++++++++++++++++++++ common/singledo/singledo_test.go | 69 ++++++++++++++++++++++++++++++ component/dialer/bind.go | 21 --------- component/dialer/bind_darwin.go | 32 -------------- component/dialer/bind_linux.go | 27 ------------ component/dialer/bind_others.go | 15 ------- component/dialer/control.go | 35 --------------- component/dialer/dialer.go | 48 ++++++++++++++++++--- component/dialer/fwmark.go | 14 ------ component/dialer/fwmark_linux.go | 21 --------- component/dialer/fwmark_others.go | 14 ------ component/dialer/sockopt.go | 19 ++++++++ component/dialer/sockopt_darwin.go | 61 ++++++++++++++++++++++++++ component/dialer/sockopt_linux.go | 39 +++++++++++++++++ component/dialer/sockopt_others.go | 7 +++ engine/engine.go | 21 +++------ 16 files changed, 310 insertions(+), 199 deletions(-) create mode 100755 common/singledo/singledo.go create mode 100755 common/singledo/singledo_test.go delete mode 100644 component/dialer/bind.go delete mode 100644 component/dialer/bind_darwin.go delete mode 100644 component/dialer/bind_linux.go delete mode 100644 component/dialer/bind_others.go delete mode 100644 component/dialer/control.go delete mode 100644 component/dialer/fwmark.go delete mode 100644 component/dialer/fwmark_linux.go delete mode 100644 component/dialer/fwmark_others.go create mode 100644 component/dialer/sockopt.go create mode 100644 component/dialer/sockopt_darwin.go create mode 100644 component/dialer/sockopt_linux.go create mode 100644 component/dialer/sockopt_others.go diff --git a/common/singledo/singledo.go b/common/singledo/singledo.go new file mode 100755 index 0000000..3f3ece2 --- /dev/null +++ b/common/singledo/singledo.go @@ -0,0 +1,66 @@ +package singledo + +// Ref: github.com/Dreamacro/clash/common/singledo + +import ( + "sync" + "time" +) + +type call struct { + wg sync.WaitGroup + val any + err error +} + +type Single struct { + mux sync.Mutex + last time.Time + wait time.Duration + call *call + result *Result +} + +type Result struct { + Val any + Err error +} + +// Do single.Do likes sync.singleFlight +//lint:ignore ST1008 it likes sync.singleFlight +func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) { + s.mux.Lock() + now := time.Now() + if now.Before(s.last.Add(s.wait)) { + s.mux.Unlock() + return s.result.Val, s.result.Err, true + } + + if call := s.call; call != nil { + s.mux.Unlock() + call.wg.Wait() + return call.val, call.err, true + } + + call := &call{} + call.wg.Add(1) + s.call = call + s.mux.Unlock() + call.val, call.err = fn() + call.wg.Done() + + s.mux.Lock() + s.call = nil + s.result = &Result{call.val, call.err} + s.last = now + s.mux.Unlock() + return call.val, call.err, false +} + +func (s *Single) Reset() { + s.last = time.Time{} +} + +func NewSingle(wait time.Duration) *Single { + return &Single{wait: wait} +} diff --git a/common/singledo/singledo_test.go b/common/singledo/singledo_test.go new file mode 100755 index 0000000..71b6ac9 --- /dev/null +++ b/common/singledo/singledo_test.go @@ -0,0 +1,69 @@ +package singledo + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestBasic(t *testing.T) { + single := NewSingle(time.Millisecond * 30) + foo := 0 + shardCount := atomic.NewInt32(0) + call := func() (any, error) { + foo++ + time.Sleep(time.Millisecond * 5) + return nil, nil + } + + var wg sync.WaitGroup + const n = 5 + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + _, _, shard := single.Do(call) + if shard { + shardCount.Inc() + } + wg.Done() + }() + } + + wg.Wait() + assert.Equal(t, 1, foo) + assert.Equal(t, int32(4), shardCount.Load()) +} + +func TestTimer(t *testing.T) { + single := NewSingle(time.Millisecond * 30) + foo := 0 + call := func() (any, error) { + foo++ + return nil, nil + } + + single.Do(call) + time.Sleep(10 * time.Millisecond) + _, _, shard := single.Do(call) + + assert.Equal(t, 1, foo) + assert.True(t, shard) +} + +func TestReset(t *testing.T) { + single := NewSingle(time.Millisecond * 30) + foo := 0 + call := func() (any, error) { + foo++ + return nil, nil + } + + single.Do(call) + single.Reset() + single.Do(call) + + assert.Equal(t, 2, foo) +} diff --git a/component/dialer/bind.go b/component/dialer/bind.go deleted file mode 100644 index fd3a098..0000000 --- a/component/dialer/bind.go +++ /dev/null @@ -1,21 +0,0 @@ -package dialer - -import ( - "net" - "sync" -) - -var _bindOnce sync.Once - -// BindToInterface binds dialer to specific interface. -func BindToInterface(name string) error { - i, err := net.InterfaceByName(name) - if err != nil { - return err - } - - _bindOnce.Do(func() { - addControl(bindToInterface(i)) - }) - return nil -} diff --git a/component/dialer/bind_darwin.go b/component/dialer/bind_darwin.go deleted file mode 100644 index b4901b0..0000000 --- a/component/dialer/bind_darwin.go +++ /dev/null @@ -1,32 +0,0 @@ -package dialer - -import ( - "net" - "syscall" - - "golang.org/x/sys/unix" -) - -func bindToInterface(i *net.Interface) controlFunc { - return func(network, address string, c syscall.RawConn) (err error) { - host, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { - return - } - - var innerErr error - err = c.Control(func(fd uintptr) { - switch network { - case "tcp4", "udp4": - innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index) - case "tcp6", "udp6": - innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index) - } - }) - - if innerErr != nil { - err = innerErr - } - return - } -} diff --git a/component/dialer/bind_linux.go b/component/dialer/bind_linux.go deleted file mode 100644 index 7237fb2..0000000 --- a/component/dialer/bind_linux.go +++ /dev/null @@ -1,27 +0,0 @@ -package dialer - -import ( - "net" - "syscall" - - "golang.org/x/sys/unix" -) - -func bindToInterface(i *net.Interface) controlFunc { - return func(network, address string, c syscall.RawConn) (err error) { - host, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { - return nil - } - - var innerErr error - err = c.Control(func(fd uintptr) { - innerErr = unix.BindToDevice(int(fd), i.Name) - }) - - if innerErr != nil { - err = innerErr - } - return - } -} diff --git a/component/dialer/bind_others.go b/component/dialer/bind_others.go deleted file mode 100644 index 8348e1b..0000000 --- a/component/dialer/bind_others.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build !linux && !darwin - -package dialer - -import ( - "errors" - "net" - "syscall" -) - -func bindToInterface(_ *net.Interface) controlFunc { - return func(string, string, syscall.RawConn) error { - return errors.New("unsupported platform") - } -} diff --git a/component/dialer/control.go b/component/dialer/control.go deleted file mode 100644 index 1f000ec..0000000 --- a/component/dialer/control.go +++ /dev/null @@ -1,35 +0,0 @@ -package dialer - -import ( - "errors" - "net" - "syscall" -) - -type controlFunc func(string, string, syscall.RawConn) error - -var _controlPool = make([]controlFunc, 0, 2) - -func addControl(f controlFunc) { - _controlPool = append(_controlPool, f) -} - -func setControl(i any) { - control := func(address, network string, c syscall.RawConn) error { - for _, f := range _controlPool { - if err := f(address, network, c); err != nil { - return err - } - } - return nil - } - - switch v := i.(type) { - case *net.Dialer: - v.Control = control - case *net.ListenConfig: - v.Control = control - default: - panic(errors.New("wrong type")) - } -} diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 72e5aa6..d313713 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -3,20 +3,56 @@ package dialer import ( "context" "net" + "syscall" + + "go.uber.org/atomic" ) -func Dial(network, address string) (net.Conn, error) { - return DialContext(context.Background(), network, address) +var ( + DefaultInterfaceName = atomic.NewString("") + DefaultRoutingMark = atomic.NewInt32(0) +) + +type Options struct { + // InterfaceName is the name of interface/device to bind. + // If a socket is bound to an interface, only packets received + // from that particular interface are processed by the socket. + InterfaceName string + + // RoutingMark is the mark for each packet sent through this + // socket. Changing the mark can be used for mark-based routing + // without netfilter or for packet filtering. + RoutingMark int } func DialContext(ctx context.Context, network, address string) (net.Conn, error) { - d := &net.Dialer{} - setControl(d) + return DialContextWithOptions(ctx, network, address, &Options{ + InterfaceName: DefaultInterfaceName.Load(), + RoutingMark: int(DefaultRoutingMark.Load()), + }) +} + +func DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) { + d := &net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return setSocketOptions(network, address, c, opts) + }, + } return d.DialContext(ctx, network, address) } func ListenPacket(network, address string) (net.PacketConn, error) { - lc := &net.ListenConfig{} - setControl(lc) + return ListenPacketWithOptions(network, address, &Options{ + InterfaceName: DefaultInterfaceName.Load(), + RoutingMark: int(DefaultRoutingMark.Load()), + }) +} + +func ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) { + lc := &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + return setSocketOptions(network, address, c, opts) + }, + } return lc.ListenPacket(context.Background(), network, address) } diff --git a/component/dialer/fwmark.go b/component/dialer/fwmark.go deleted file mode 100644 index c7e882b..0000000 --- a/component/dialer/fwmark.go +++ /dev/null @@ -1,14 +0,0 @@ -package dialer - -import ( - "sync" -) - -var _setOnce sync.Once - -// SetMark sets the mark for each packet sent through this dialer(socket). -func SetMark(i int) { - _setOnce.Do(func() { - addControl(setMark(i)) - }) -} diff --git a/component/dialer/fwmark_linux.go b/component/dialer/fwmark_linux.go deleted file mode 100644 index 8b0cbbd..0000000 --- a/component/dialer/fwmark_linux.go +++ /dev/null @@ -1,21 +0,0 @@ -package dialer - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func setMark(m int) controlFunc { - return func(_, _ string, c syscall.RawConn) (err error) { - var innerErr error - err = c.Control(func(fd uintptr) { - innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, m) - }) - - if innerErr != nil { - err = innerErr - } - return - } -} diff --git a/component/dialer/fwmark_others.go b/component/dialer/fwmark_others.go deleted file mode 100644 index b00b301..0000000 --- a/component/dialer/fwmark_others.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !linux - -package dialer - -import ( - "errors" - "syscall" -) - -func setMark(_ int) controlFunc { - return func(string, string, syscall.RawConn) error { - return errors.New("fwmark: linux only") - } -} diff --git a/component/dialer/sockopt.go b/component/dialer/sockopt.go new file mode 100644 index 0000000..76c3c41 --- /dev/null +++ b/component/dialer/sockopt.go @@ -0,0 +1,19 @@ +package dialer + +func isTCPSocket(network string) bool { + switch network { + case "tcp", "tcp4", "tcp6": + return true + default: + return false + } +} + +func isUDPSocket(network string) bool { + switch network { + case "udp", "udp4", "udp6": + return true + default: + return false + } +} diff --git a/component/dialer/sockopt_darwin.go b/component/dialer/sockopt_darwin.go new file mode 100644 index 0000000..7aea95a --- /dev/null +++ b/component/dialer/sockopt_darwin.go @@ -0,0 +1,61 @@ +package dialer + +import ( + "net" + "syscall" + "time" + + "github.com/xjasonlyu/tun2socks/v2/common/singledo" + + "golang.org/x/sys/unix" +) + +var interfaces = singledo.NewSingle(30 * time.Second) + +func resolveInterfaceByName(name string) (*net.Interface, error) { + value, err, _ := interfaces.Do(func() (any, error) { + return net.InterfaceByName(name) + }) + if err != nil { + return nil, err + } + return value.(*net.Interface), nil +} + +func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { + if !isTCPSocket(network) && !isUDPSocket(network) { + return + } + + var innerErr error + err = c.Control(func(fd uintptr) { + // must be GlobalUnicast. + host, _, _ := net.SplitHostPort(address) + if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { + return + } + + if opts.InterfaceName != "" { + var iface *net.Interface + iface, innerErr = resolveInterfaceByName(opts.InterfaceName) + if innerErr != nil { + return + } + + switch network { + case "tcp4", "udp4": + innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index) + case "tcp6", "udp6": + innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index) + } + if innerErr != nil { + return + } + } + }) + + if innerErr != nil { + err = innerErr + } + return +} diff --git a/component/dialer/sockopt_linux.go b/component/dialer/sockopt_linux.go new file mode 100644 index 0000000..0dd5217 --- /dev/null +++ b/component/dialer/sockopt_linux.go @@ -0,0 +1,39 @@ +package dialer + +import ( + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { + if !isTCPSocket(network) && !isUDPSocket(network) { + return + } + + var innerErr error + err = c.Control(func(fd uintptr) { + // must be GlobalUnicast. + host, _, _ := net.SplitHostPort(address) + if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { + return + } + + if opts.InterfaceName != "" { + if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil { + return + } + } + if opts.RoutingMark != 0 { + if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil { + return + } + } + }) + + if innerErr != nil { + err = innerErr + } + return +} diff --git a/component/dialer/sockopt_others.go b/component/dialer/sockopt_others.go new file mode 100644 index 0000000..00224d6 --- /dev/null +++ b/component/dialer/sockopt_others.go @@ -0,0 +1,7 @@ +//go:build !linux && !darwin + +package dialer + +func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error { + return nil +} diff --git a/engine/engine.go b/engine/engine.go index b2d302a..0ee4362 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -65,8 +65,7 @@ func (e *engine) start() error { for _, f := range []func() error{ e.applyLogLevel, - e.applyMark, - e.applyInterface, + e.applyDialer, e.applyStats, e.applyUDPTimeout, e.applyProxy, @@ -104,21 +103,15 @@ func (e *engine) applyLogLevel() error { return nil } -func (e *engine) applyMark() error { - if e.Mark != 0 { - dialer.SetMark(e.Mark) - log.Infof("[DIALER] set fwmark: %#x", e.Mark) - } - return nil -} - -func (e *engine) applyInterface() error { +func (e *engine) applyDialer() error { if e.Interface != "" { - if err := dialer.BindToInterface(e.Interface); err != nil { - return err - } + dialer.DefaultInterfaceName.Store(e.Interface) log.Infof("[DIALER] bind to interface: %s", e.Interface) } + if e.Mark != 0 { + dialer.DefaultRoutingMark.Store(int32(e.Mark)) + log.Infof("[DIALER] set fwmark: %#x", e.Mark) + } return nil }