Refactor(tunnel): modularize tunnel pkg (#393)

This commit is contained in:
Jason Lyu
2024-08-31 11:31:18 +08:00
committed by GitHub
parent 71c45ef87e
commit fd98f65994
8 changed files with 152 additions and 75 deletions

View File

@@ -16,7 +16,6 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/device"
"github.com/xjasonlyu/tun2socks/v2/core/option" "github.com/xjasonlyu/tun2socks/v2/core/option"
"github.com/xjasonlyu/tun2socks/v2/dialer" "github.com/xjasonlyu/tun2socks/v2/dialer"
"github.com/xjasonlyu/tun2socks/v2/engine/mirror"
"github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/log"
"github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/restapi" "github.com/xjasonlyu/tun2socks/v2/restapi"
@@ -130,7 +129,7 @@ func general(k *Key) error {
if k.UDPTimeout < time.Second { if k.UDPTimeout < time.Second {
return errors.New("invalid udp timeout value") return errors.New("invalid udp timeout value")
} }
tunnel.SetUDPTimeout(k.UDPTimeout) tunnel.T().SetUDPTimeout(k.UDPTimeout)
} }
return nil return nil
} }
@@ -192,7 +191,7 @@ func netstack(k *Key) (err error) {
if _defaultProxy, err = parseProxy(k.Proxy); err != nil { if _defaultProxy, err = parseProxy(k.Proxy); err != nil {
return return
} }
proxy.SetDialer(_defaultProxy) tunnel.T().SetDialer(_defaultProxy)
if _defaultDevice, err = parseDevice(k.Device, uint32(k.MTU)); err != nil { if _defaultDevice, err = parseDevice(k.Device, uint32(k.MTU)); err != nil {
return return
@@ -226,7 +225,7 @@ func netstack(k *Key) (err error) {
if _defaultStack, err = core.CreateStack(&core.Config{ if _defaultStack, err = core.CreateStack(&core.Config{
LinkEndpoint: _defaultDevice, LinkEndpoint: _defaultDevice,
TransportHandler: &mirror.Tunnel{}, TransportHandler: tunnel.T(),
MulticastGroups: multicastGroups, MulticastGroups: multicastGroups,
Options: opts, Options: opts,
}); err != nil { }); err != nil {

View File

@@ -1,18 +0,0 @@
package mirror
import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/tunnel"
)
var _ adapter.TransportHandler = (*Tunnel)(nil)
type Tunnel struct{}
func (*Tunnel) HandleTCP(conn adapter.TCPConn) {
tunnel.TCPIn() <- conn
}
func (*Tunnel) HandleUDP(conn adapter.UDPConn) {
tunnel.UDPIn() <- conn
}

37
tunnel/global.go Normal file
View File

@@ -0,0 +1,37 @@
package tunnel
import (
"sync"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)
var (
_globalMu sync.RWMutex
_globalT *Tunnel
)
func init() {
ReplaceGlobal(New(&proxy.Base{}, statistic.DefaultManager))
T().ProcessAsync()
}
// T returns the global Tunnel, which can be reconfigured with
// ReplaceGlobal. It's safe for concurrent use.
func T() *Tunnel {
_globalMu.RLock()
t := _globalT
_globalMu.RUnlock()
return t
}
// ReplaceGlobal replaces the global Tunnel, and returns a function
// to restore the original values. It's safe for concurrent use.
func ReplaceGlobal(t *Tunnel) func() {
_globalMu.Lock()
prev := _globalT
_globalT = t
_globalMu.Unlock()
return func() { ReplaceGlobal(prev) }
}

View File

@@ -18,7 +18,6 @@ func init() {
uploadTotal: atomic.NewInt64(0), uploadTotal: atomic.NewInt64(0),
downloadTotal: atomic.NewInt64(0), downloadTotal: atomic.NewInt64(0),
} }
go DefaultManager.handle() go DefaultManager.handle()
} }

View File

@@ -50,11 +50,6 @@ func NewTCPTracker(conn net.Conn, metadata *M.Metadata, manager *Manager) net.Co
return tt return tt
} }
// DefaultTCPTracker returns a new net.Conn(*tcpTacker) with default manager.
func DefaultTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
return NewTCPTracker(conn, metadata, DefaultManager)
}
func (tt *tcpTracker) ID() string { func (tt *tcpTracker) ID() string {
return tt.UUID.String() return tt.UUID.String()
} }
@@ -120,11 +115,6 @@ func NewUDPTracker(conn net.PacketConn, metadata *M.Metadata, manager *Manager)
return ut return ut
} }
// DefaultUDPTracker returns a new net.PacketConn(*udpTacker) with default manager.
func DefaultUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
return NewUDPTracker(conn, metadata, DefaultManager)
}
func (ut *udpTracker) ID() string { func (ut *udpTracker) ID() string {
return ut.UUID.String() return ut.UUID.String()
} }

View File

@@ -1,6 +1,7 @@
package tunnel package tunnel
import ( import (
"context"
"io" "io"
"net" "net"
"sync" "sync"
@@ -10,16 +11,10 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata" M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
) )
const ( func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) {
// tcpWaitTimeout implements a TCP half-close timeout.
tcpWaitTimeout = 60 * time.Second
)
func handleTCPConn(originConn adapter.TCPConn) {
defer originConn.Close() defer originConn.Close()
id := originConn.ID() id := originConn.ID()
@@ -31,21 +26,24 @@ func handleTCPConn(originConn adapter.TCPConn) {
DstPort: id.LocalPort, DstPort: id.LocalPort,
} }
remoteConn, err := proxy.Dial(metadata) ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
remoteConn, err := t.Dialer().DialContext(ctx, metadata)
if err != nil { if err != nil {
log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err) log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err)
return return
} }
metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr()) metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr())
remoteConn = statistic.DefaultTCPTracker(remoteConn, metadata) remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager)
defer remoteConn.Close() defer remoteConn.Close()
log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
pipe(originConn, remoteConn) pipe(originConn, remoteConn)
} }
// pipe copies copy data to & from provided net.Conn(s) bidirectionally. // pipe copies data to & from provided net.Conn(s) bidirectionally.
func pipe(origin, remote net.Conn) { func pipe(origin, remote net.Conn) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)

View File

@@ -1,36 +1,116 @@
package tunnel package tunnel
import ( import (
"context"
"sync"
"time"
"go.uber.org/atomic"
"github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
) )
const (
// tcpConnectTimeout is the default timeout for TCP handshakes.
tcpConnectTimeout = 5 * time.Second
// tcpWaitTimeout implements a TCP half-close timeout.
tcpWaitTimeout = 60 * time.Second
// udpSessionTimeout is the default timeout for UDP sessions.
udpSessionTimeout = 60 * time.Second
)
var _ adapter.TransportHandler = (*Tunnel)(nil)
type Tunnel struct {
// Unbuffered TCP/UDP queues. // Unbuffered TCP/UDP queues.
var ( tcpQueue chan adapter.TCPConn
_tcpQueue = make(chan adapter.TCPConn) udpQueue chan adapter.UDPConn
_udpQueue = make(chan adapter.UDPConn)
)
func init() { // UDP session timeout.
go process() udpTimeout *atomic.Duration
// Internal proxy.Dialer for Tunnel.
dialerMu sync.RWMutex
dialer proxy.Dialer
// Where the Tunnel statistics are sent to.
manager *statistic.Manager
procOnce sync.Once
procCancel context.CancelFunc
}
func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel {
return &Tunnel{
tcpQueue: make(chan adapter.TCPConn),
udpQueue: make(chan adapter.UDPConn),
udpTimeout: atomic.NewDuration(udpSessionTimeout),
dialer: dialer,
manager: manager,
procCancel: func() { /* nop */ },
}
} }
// TCPIn return fan-in TCP queue. // TCPIn return fan-in TCP queue.
func TCPIn() chan<- adapter.TCPConn { func (t *Tunnel) TCPIn() chan<- adapter.TCPConn {
return _tcpQueue return t.tcpQueue
} }
// UDPIn return fan-in UDP queue. // UDPIn return fan-in UDP queue.
func UDPIn() chan<- adapter.UDPConn { func (t *Tunnel) UDPIn() chan<- adapter.UDPConn {
return _udpQueue return t.udpQueue
} }
func process() { func (t *Tunnel) HandleTCP(conn adapter.TCPConn) {
t.TCPIn() <- conn
}
func (t *Tunnel) HandleUDP(conn adapter.UDPConn) {
t.UDPIn() <- conn
}
func (t *Tunnel) process(ctx context.Context) {
for { for {
select { select {
case conn := <-_tcpQueue: case conn := <-t.tcpQueue:
go handleTCPConn(conn) go t.handleTCPConn(conn)
case conn := <-_udpQueue: case conn := <-t.udpQueue:
go handleUDPConn(conn) go t.handleUDPConn(conn)
case <-ctx.Done():
return
} }
} }
} }
// ProcessAsync can be safely called multiple times, but will only be effective once.
func (t *Tunnel) ProcessAsync() {
t.procOnce.Do(func() {
ctx, cancel := context.WithCancel(context.Background())
t.procCancel = cancel
go t.process(ctx)
})
}
// Close closes the Tunnel and releases its resources.
func (t *Tunnel) Close() {
t.procCancel()
}
func (t *Tunnel) Dialer() proxy.Dialer {
t.dialerMu.RLock()
d := t.dialer
t.dialerMu.RUnlock()
return d
}
func (t *Tunnel) SetDialer(dialer proxy.Dialer) {
t.dialerMu.Lock()
t.dialer = dialer
t.dialerMu.Unlock()
}
func (t *Tunnel) SetUDPTimeout(timeout time.Duration) {
t.udpTimeout.Store(timeout)
}

View File

@@ -10,19 +10,11 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata" M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
) )
// _udpSessionTimeout is the default timeout for each UDP session.
var _udpSessionTimeout = 60 * time.Second
func SetUDPTimeout(t time.Duration) {
_udpSessionTimeout = t
}
// TODO: Port Restricted NAT support. // TODO: Port Restricted NAT support.
func handleUDPConn(uc adapter.UDPConn) { func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) {
defer uc.Close() defer uc.Close()
id := uc.ID() id := uc.ID()
@@ -34,14 +26,14 @@ func handleUDPConn(uc adapter.UDPConn) {
DstPort: id.LocalPort, DstPort: id.LocalPort,
} }
pc, err := proxy.DialUDP(metadata) pc, err := t.Dialer().DialUDP(metadata)
if err != nil { if err != nil {
log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err)
return return
} }
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr()) metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())
pc = statistic.DefaultUDPTracker(pc, metadata) pc = statistic.NewUDPTracker(pc, metadata, t.manager)
defer pc.Close() defer pc.Close()
var remote net.Addr var remote net.Addr
@@ -53,22 +45,22 @@ func handleUDPConn(uc adapter.UDPConn) {
pc = newSymmetricNATPacketConn(pc, metadata) pc = newSymmetricNATPacketConn(pc, metadata)
log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
pipePacket(uc, pc, remote) pipePacket(uc, pc, remote, t.udpTimeout.Load())
} }
func pipePacket(origin, remote net.PacketConn, to net.Addr) { func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg) go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout)
go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg) go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout)
wg.Wait() wg.Wait()
} }
func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup) { func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done() defer wg.Done()
if err := copyPacketData(dst, src, to, _udpSessionTimeout); err != nil { if err := copyPacketData(dst, src, to, timeout); err != nil {
log.Debugf("[UDP] copy data for %s: %v", dir, err) log.Debugf("[UDP] copy data for %s: %v", dir, err)
} }
} }