mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-06 09:16:58 +08:00
Refactor: new dialer impl
This commit is contained in:
66
common/singledo/singledo.go
Executable file
66
common/singledo/singledo.go
Executable file
@@ -0,0 +1,66 @@
|
|||||||
|
package singledo
|
||||||
|
|
||||||
|
// Ref: github.com/Dreamacro/clash/common/singledo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type call struct {
|
||||||
|
wg sync.WaitGroup
|
||||||
|
val any
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Single struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
last time.Time
|
||||||
|
wait time.Duration
|
||||||
|
call *call
|
||||||
|
result *Result
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Val any
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do single.Do likes sync.singleFlight
|
||||||
|
//lint:ignore ST1008 it likes sync.singleFlight
|
||||||
|
func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) {
|
||||||
|
s.mux.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
if now.Before(s.last.Add(s.wait)) {
|
||||||
|
s.mux.Unlock()
|
||||||
|
return s.result.Val, s.result.Err, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if call := s.call; call != nil {
|
||||||
|
s.mux.Unlock()
|
||||||
|
call.wg.Wait()
|
||||||
|
return call.val, call.err, true
|
||||||
|
}
|
||||||
|
|
||||||
|
call := &call{}
|
||||||
|
call.wg.Add(1)
|
||||||
|
s.call = call
|
||||||
|
s.mux.Unlock()
|
||||||
|
call.val, call.err = fn()
|
||||||
|
call.wg.Done()
|
||||||
|
|
||||||
|
s.mux.Lock()
|
||||||
|
s.call = nil
|
||||||
|
s.result = &Result{call.val, call.err}
|
||||||
|
s.last = now
|
||||||
|
s.mux.Unlock()
|
||||||
|
return call.val, call.err, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Single) Reset() {
|
||||||
|
s.last = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSingle(wait time.Duration) *Single {
|
||||||
|
return &Single{wait: wait}
|
||||||
|
}
|
69
common/singledo/singledo_test.go
Executable file
69
common/singledo/singledo_test.go
Executable file
@@ -0,0 +1,69 @@
|
|||||||
|
package singledo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasic(t *testing.T) {
|
||||||
|
single := NewSingle(time.Millisecond * 30)
|
||||||
|
foo := 0
|
||||||
|
shardCount := atomic.NewInt32(0)
|
||||||
|
call := func() (any, error) {
|
||||||
|
foo++
|
||||||
|
time.Sleep(time.Millisecond * 5)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
const n = 5
|
||||||
|
wg.Add(n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
go func() {
|
||||||
|
_, _, shard := single.Do(call)
|
||||||
|
if shard {
|
||||||
|
shardCount.Inc()
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, 1, foo)
|
||||||
|
assert.Equal(t, int32(4), shardCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimer(t *testing.T) {
|
||||||
|
single := NewSingle(time.Millisecond * 30)
|
||||||
|
foo := 0
|
||||||
|
call := func() (any, error) {
|
||||||
|
foo++
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
single.Do(call)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
_, _, shard := single.Do(call)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, foo)
|
||||||
|
assert.True(t, shard)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReset(t *testing.T) {
|
||||||
|
single := NewSingle(time.Millisecond * 30)
|
||||||
|
foo := 0
|
||||||
|
call := func() (any, error) {
|
||||||
|
foo++
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
single.Do(call)
|
||||||
|
single.Reset()
|
||||||
|
single.Do(call)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, foo)
|
||||||
|
}
|
@@ -1,21 +0,0 @@
|
|||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _bindOnce sync.Once
|
|
||||||
|
|
||||||
// BindToInterface binds dialer to specific interface.
|
|
||||||
func BindToInterface(name string) error {
|
|
||||||
i, err := net.InterfaceByName(name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_bindOnce.Do(func() {
|
|
||||||
addControl(bindToInterface(i))
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
@@ -1,32 +0,0 @@
|
|||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func bindToInterface(i *net.Interface) controlFunc {
|
|
||||||
return func(network, address string, c syscall.RawConn) (err error) {
|
|
||||||
host, _, _ := net.SplitHostPort(address)
|
|
||||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var innerErr error
|
|
||||||
err = c.Control(func(fd uintptr) {
|
|
||||||
switch network {
|
|
||||||
case "tcp4", "udp4":
|
|
||||||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index)
|
|
||||||
case "tcp6", "udp6":
|
|
||||||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if innerErr != nil {
|
|
||||||
err = innerErr
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,27 +0,0 @@
|
|||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func bindToInterface(i *net.Interface) controlFunc {
|
|
||||||
return func(network, address string, c syscall.RawConn) (err error) {
|
|
||||||
host, _, _ := net.SplitHostPort(address)
|
|
||||||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var innerErr error
|
|
||||||
err = c.Control(func(fd uintptr) {
|
|
||||||
innerErr = unix.BindToDevice(int(fd), i.Name)
|
|
||||||
})
|
|
||||||
|
|
||||||
if innerErr != nil {
|
|
||||||
err = innerErr
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,15 +0,0 @@
|
|||||||
//go:build !linux && !darwin
|
|
||||||
|
|
||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
func bindToInterface(_ *net.Interface) controlFunc {
|
|
||||||
return func(string, string, syscall.RawConn) error {
|
|
||||||
return errors.New("unsupported platform")
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,35 +0,0 @@
|
|||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
type controlFunc func(string, string, syscall.RawConn) error
|
|
||||||
|
|
||||||
var _controlPool = make([]controlFunc, 0, 2)
|
|
||||||
|
|
||||||
func addControl(f controlFunc) {
|
|
||||||
_controlPool = append(_controlPool, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setControl(i any) {
|
|
||||||
control := func(address, network string, c syscall.RawConn) error {
|
|
||||||
for _, f := range _controlPool {
|
|
||||||
if err := f(address, network, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := i.(type) {
|
|
||||||
case *net.Dialer:
|
|
||||||
v.Control = control
|
|
||||||
case *net.ListenConfig:
|
|
||||||
v.Control = control
|
|
||||||
default:
|
|
||||||
panic(errors.New("wrong type"))
|
|
||||||
}
|
|
||||||
}
|
|
@@ -3,20 +3,56 @@ package dialer
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Dial(network, address string) (net.Conn, error) {
|
var (
|
||||||
return DialContext(context.Background(), network, address)
|
DefaultInterfaceName = atomic.NewString("")
|
||||||
|
DefaultRoutingMark = atomic.NewInt32(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
// InterfaceName is the name of interface/device to bind.
|
||||||
|
// If a socket is bound to an interface, only packets received
|
||||||
|
// from that particular interface are processed by the socket.
|
||||||
|
InterfaceName string
|
||||||
|
|
||||||
|
// RoutingMark is the mark for each packet sent through this
|
||||||
|
// socket. Changing the mark can be used for mark-based routing
|
||||||
|
// without netfilter or for packet filtering.
|
||||||
|
RoutingMark int
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
d := &net.Dialer{}
|
return DialContextWithOptions(ctx, network, address, &Options{
|
||||||
setControl(d)
|
InterfaceName: DefaultInterfaceName.Load(),
|
||||||
|
RoutingMark: int(DefaultRoutingMark.Load()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) {
|
||||||
|
d := &net.Dialer{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
return setSocketOptions(network, address, c, opts)
|
||||||
|
},
|
||||||
|
}
|
||||||
return d.DialContext(ctx, network, address)
|
return d.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenPacket(network, address string) (net.PacketConn, error) {
|
func ListenPacket(network, address string) (net.PacketConn, error) {
|
||||||
lc := &net.ListenConfig{}
|
return ListenPacketWithOptions(network, address, &Options{
|
||||||
setControl(lc)
|
InterfaceName: DefaultInterfaceName.Load(),
|
||||||
|
RoutingMark: int(DefaultRoutingMark.Load()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) {
|
||||||
|
lc := &net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
return setSocketOptions(network, address, c, opts)
|
||||||
|
},
|
||||||
|
}
|
||||||
return lc.ListenPacket(context.Background(), network, address)
|
return lc.ListenPacket(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
@@ -1,14 +0,0 @@
|
|||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _setOnce sync.Once
|
|
||||||
|
|
||||||
// SetMark sets the mark for each packet sent through this dialer(socket).
|
|
||||||
func SetMark(i int) {
|
|
||||||
_setOnce.Do(func() {
|
|
||||||
addControl(setMark(i))
|
|
||||||
})
|
|
||||||
}
|
|
@@ -1,21 +0,0 @@
|
|||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setMark(m int) controlFunc {
|
|
||||||
return func(_, _ string, c syscall.RawConn) (err error) {
|
|
||||||
var innerErr error
|
|
||||||
err = c.Control(func(fd uintptr) {
|
|
||||||
innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, m)
|
|
||||||
})
|
|
||||||
|
|
||||||
if innerErr != nil {
|
|
||||||
err = innerErr
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
@@ -1,14 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setMark(_ int) controlFunc {
|
|
||||||
return func(string, string, syscall.RawConn) error {
|
|
||||||
return errors.New("fwmark: linux only")
|
|
||||||
}
|
|
||||||
}
|
|
19
component/dialer/sockopt.go
Normal file
19
component/dialer/sockopt.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
func isTCPSocket(network string) bool {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUDPSocket(network string) bool {
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
61
component/dialer/sockopt_darwin.go
Normal file
61
component/dialer/sockopt_darwin.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/xjasonlyu/tun2socks/v2/common/singledo"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
var interfaces = singledo.NewSingle(30 * time.Second)
|
||||||
|
|
||||||
|
func resolveInterfaceByName(name string) (*net.Interface, error) {
|
||||||
|
value, err, _ := interfaces.Do(func() (any, error) {
|
||||||
|
return net.InterfaceByName(name)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return value.(*net.Interface), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||||
|
if !isTCPSocket(network) && !isUDPSocket(network) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var innerErr error
|
||||||
|
err = c.Control(func(fd uintptr) {
|
||||||
|
// must be GlobalUnicast.
|
||||||
|
host, _, _ := net.SplitHostPort(address)
|
||||||
|
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.InterfaceName != "" {
|
||||||
|
var iface *net.Interface
|
||||||
|
iface, innerErr = resolveInterfaceByName(opts.InterfaceName)
|
||||||
|
if innerErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch network {
|
||||||
|
case "tcp4", "udp4":
|
||||||
|
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index)
|
||||||
|
case "tcp6", "udp6":
|
||||||
|
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index)
|
||||||
|
}
|
||||||
|
if innerErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if innerErr != nil {
|
||||||
|
err = innerErr
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
39
component/dialer/sockopt_linux.go
Normal file
39
component/dialer/sockopt_linux.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
|
||||||
|
if !isTCPSocket(network) && !isUDPSocket(network) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var innerErr error
|
||||||
|
err = c.Control(func(fd uintptr) {
|
||||||
|
// must be GlobalUnicast.
|
||||||
|
host, _, _ := net.SplitHostPort(address)
|
||||||
|
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.InterfaceName != "" {
|
||||||
|
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if opts.RoutingMark != 0 {
|
||||||
|
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if innerErr != nil {
|
||||||
|
err = innerErr
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
7
component/dialer/sockopt_others.go
Normal file
7
component/dialer/sockopt_others.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux && !darwin
|
||||||
|
|
||||||
|
package dialer
|
||||||
|
|
||||||
|
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error {
|
||||||
|
return nil
|
||||||
|
}
|
@@ -65,8 +65,7 @@ func (e *engine) start() error {
|
|||||||
|
|
||||||
for _, f := range []func() error{
|
for _, f := range []func() error{
|
||||||
e.applyLogLevel,
|
e.applyLogLevel,
|
||||||
e.applyMark,
|
e.applyDialer,
|
||||||
e.applyInterface,
|
|
||||||
e.applyStats,
|
e.applyStats,
|
||||||
e.applyUDPTimeout,
|
e.applyUDPTimeout,
|
||||||
e.applyProxy,
|
e.applyProxy,
|
||||||
@@ -104,21 +103,15 @@ func (e *engine) applyLogLevel() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *engine) applyMark() error {
|
func (e *engine) applyDialer() error {
|
||||||
if e.Mark != 0 {
|
|
||||||
dialer.SetMark(e.Mark)
|
|
||||||
log.Infof("[DIALER] set fwmark: %#x", e.Mark)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *engine) applyInterface() error {
|
|
||||||
if e.Interface != "" {
|
if e.Interface != "" {
|
||||||
if err := dialer.BindToInterface(e.Interface); err != nil {
|
dialer.DefaultInterfaceName.Store(e.Interface)
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Infof("[DIALER] bind to interface: %s", e.Interface)
|
log.Infof("[DIALER] bind to interface: %s", e.Interface)
|
||||||
}
|
}
|
||||||
|
if e.Mark != 0 {
|
||||||
|
dialer.DefaultRoutingMark.Store(int32(e.Mark))
|
||||||
|
log.Infof("[DIALER] set fwmark: %#x", e.Mark)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user