Improve: use interface index for macos

This commit is contained in:
xjasonlyu
2022-04-02 16:08:55 +08:00
parent 289ea82829
commit 992e716216
6 changed files with 34 additions and 166 deletions

View File

@@ -1,66 +0,0 @@
package singledo
// Ref: github.com/Dreamacro/clash/common/singledo
import (
"sync"
"time"
)
type call struct {
wg sync.WaitGroup
val any
err error
}
type Single struct {
mux sync.Mutex
last time.Time
wait time.Duration
call *call
result *Result
}
type Result struct {
Val any
Err error
}
// Do single.Do likes sync.singleFlight
//lint:ignore ST1008 it likes sync.singleFlight
func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) {
s.mux.Lock()
now := time.Now()
if now.Before(s.last.Add(s.wait)) {
s.mux.Unlock()
return s.result.Val, s.result.Err, true
}
if call := s.call; call != nil {
s.mux.Unlock()
call.wg.Wait()
return call.val, call.err, true
}
call := &call{}
call.wg.Add(1)
s.call = call
s.mux.Unlock()
call.val, call.err = fn()
call.wg.Done()
s.mux.Lock()
s.call = nil
s.result = &Result{call.val, call.err}
s.last = now
s.mux.Unlock()
return call.val, call.err, false
}
func (s *Single) Reset() {
s.last = time.Time{}
}
func NewSingle(wait time.Duration) *Single {
return &Single{wait: wait}
}

View File

@@ -1,69 +0,0 @@
package singledo
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
)
func TestBasic(t *testing.T) {
single := NewSingle(time.Millisecond * 30)
foo := 0
shardCount := atomic.NewInt32(0)
call := func() (any, error) {
foo++
time.Sleep(time.Millisecond * 5)
return nil, nil
}
var wg sync.WaitGroup
const n = 5
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
_, _, shard := single.Do(call)
if shard {
shardCount.Inc()
}
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, 1, foo)
assert.Equal(t, int32(4), shardCount.Load())
}
func TestTimer(t *testing.T) {
single := NewSingle(time.Millisecond * 30)
foo := 0
call := func() (any, error) {
foo++
return nil, nil
}
single.Do(call)
time.Sleep(10 * time.Millisecond)
_, _, shard := single.Do(call)
assert.Equal(t, 1, foo)
assert.True(t, shard)
}
func TestReset(t *testing.T) {
single := NewSingle(time.Millisecond * 30)
foo := 0
call := func() (any, error) {
foo++
return nil, nil
}
single.Do(call)
single.Reset()
single.Do(call)
assert.Equal(t, 2, foo)
}

View File

@@ -10,6 +10,7 @@ import (
var (
DefaultInterfaceName = atomic.NewString("")
DefaultInterfaceIndex = atomic.NewInt32(0)
DefaultRoutingMark = atomic.NewInt32(0)
)
@@ -19,6 +20,11 @@ type Options struct {
// from that particular interface are processed by the socket.
InterfaceName string
// InterfaceIndex is the index of interface/device to bind.
// It is almost the same as InterfaceName except it uses the
// index of the interface instead of the name.
InterfaceIndex int
// RoutingMark is the mark for each packet sent through this
// socket. Changing the mark can be used for mark-based routing
// without netfilter or for packet filtering.
@@ -28,6 +34,7 @@ type Options struct {
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DialContextWithOptions(ctx, network, address, &Options{
InterfaceName: DefaultInterfaceName.Load(),
InterfaceIndex: int(DefaultInterfaceIndex.Load()),
RoutingMark: int(DefaultRoutingMark.Load()),
})
}
@@ -44,6 +51,7 @@ func DialContextWithOptions(ctx context.Context, network, address string, opts *
func ListenPacket(network, address string) (net.PacketConn, error) {
return ListenPacketWithOptions(network, address, &Options{
InterfaceName: DefaultInterfaceName.Load(),
InterfaceIndex: int(DefaultInterfaceIndex.Load()),
RoutingMark: int(DefaultRoutingMark.Load()),
})
}

View File

@@ -3,25 +3,10 @@ package dialer
import (
"net"
"syscall"
"time"
"github.com/xjasonlyu/tun2socks/v2/common/singledo"
"golang.org/x/sys/unix"
)
var interfaces = singledo.NewSingle(30 * time.Second)
func resolveInterfaceByName(name string) (*net.Interface, error) {
value, err, _ := interfaces.Do(func() (any, error) {
return net.InterfaceByName(name)
})
if err != nil {
return nil, err
}
return value.(*net.Interface), nil
}
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if opts == nil || !isTCPSocket(network) && !isUDPSocket(network) {
return
@@ -29,24 +14,23 @@ func setSocketOptions(network, address string, c syscall.RawConn, opts *Options)
var innerErr error
err = c.Control(func(fd uintptr) {
// must be GlobalUnicast.
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceName != "" {
var iface *net.Interface
iface, innerErr = resolveInterfaceByName(opts.InterfaceName)
if innerErr != nil {
return
if opts.InterfaceIndex == 0 && opts.InterfaceName != "" {
if iface, err := net.InterfaceByName(opts.InterfaceName); err == nil {
opts.InterfaceIndex = iface.Index
}
}
if opts.InterfaceIndex != 0 {
switch network {
case "tcp4", "udp4":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index)
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, opts.InterfaceIndex)
case "tcp6", "udp6":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index)
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, opts.InterfaceIndex)
}
if innerErr != nil {
return

View File

@@ -14,12 +14,17 @@ func setSocketOptions(network, address string, c syscall.RawConn, opts *Options)
var innerErr error
err = c.Control(func(fd uintptr) {
// must be GlobalUnicast.
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceName == "" && opts.InterfaceIndex != 0 {
if iface, err := net.InterfaceByIndex(opts.InterfaceIndex); err == nil {
opts.InterfaceName = iface.Name
}
}
if opts.InterfaceName != "" {
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil {
return

View File

@@ -3,6 +3,7 @@ package engine
import (
"errors"
"fmt"
"net"
"sync"
"github.com/xjasonlyu/tun2socks/v2/component/dialer"
@@ -97,7 +98,12 @@ func general(k *Key) error {
log.SetLevel(level)
if k.Interface != "" {
dialer.DefaultInterfaceName.Store(k.Interface)
iface, err := net.InterfaceByName(k.Interface)
if err != nil {
return err
}
dialer.DefaultInterfaceName.Store(iface.Name)
dialer.DefaultInterfaceIndex.Store(int32(iface.Index))
log.Infof("[DIALER] bind to interface: %s", k.Interface)
}