diff --git a/engine/engine.go b/engine/engine.go index f513fac..2ad28fc 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -16,7 +16,6 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/option" "github.com/xjasonlyu/tun2socks/v2/dialer" - "github.com/xjasonlyu/tun2socks/v2/engine/mirror" "github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/restapi" @@ -130,7 +129,7 @@ func general(k *Key) error { if k.UDPTimeout < time.Second { return errors.New("invalid udp timeout value") } - tunnel.SetUDPTimeout(k.UDPTimeout) + tunnel.T().SetUDPTimeout(k.UDPTimeout) } return nil } @@ -192,7 +191,7 @@ func netstack(k *Key) (err error) { if _defaultProxy, err = parseProxy(k.Proxy); err != nil { return } - proxy.SetDialer(_defaultProxy) + tunnel.T().SetDialer(_defaultProxy) if _defaultDevice, err = parseDevice(k.Device, uint32(k.MTU)); err != nil { return @@ -226,7 +225,7 @@ func netstack(k *Key) (err error) { if _defaultStack, err = core.CreateStack(&core.Config{ LinkEndpoint: _defaultDevice, - TransportHandler: &mirror.Tunnel{}, + TransportHandler: tunnel.T(), MulticastGroups: multicastGroups, Options: opts, }); err != nil { diff --git a/engine/mirror/tunnel.go b/engine/mirror/tunnel.go deleted file mode 100644 index b753ac3..0000000 --- a/engine/mirror/tunnel.go +++ /dev/null @@ -1,18 +0,0 @@ -package mirror - -import ( - "github.com/xjasonlyu/tun2socks/v2/core/adapter" - "github.com/xjasonlyu/tun2socks/v2/tunnel" -) - -var _ adapter.TransportHandler = (*Tunnel)(nil) - -type Tunnel struct{} - -func (*Tunnel) HandleTCP(conn adapter.TCPConn) { - tunnel.TCPIn() <- conn -} - -func (*Tunnel) HandleUDP(conn adapter.UDPConn) { - tunnel.UDPIn() <- conn -} diff --git a/tunnel/global.go b/tunnel/global.go new file mode 100644 index 0000000..28e8a00 --- /dev/null +++ b/tunnel/global.go @@ -0,0 +1,37 @@ +package tunnel + +import ( + "sync" + + "github.com/xjasonlyu/tun2socks/v2/proxy" + "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" +) + +var ( + _globalMu sync.RWMutex + _globalT *Tunnel +) + +func init() { + ReplaceGlobal(New(&proxy.Base{}, statistic.DefaultManager)) + T().ProcessAsync() +} + +// T returns the global Tunnel, which can be reconfigured with +// ReplaceGlobal. It's safe for concurrent use. +func T() *Tunnel { + _globalMu.RLock() + t := _globalT + _globalMu.RUnlock() + return t +} + +// ReplaceGlobal replaces the global Tunnel, and returns a function +// to restore the original values. It's safe for concurrent use. +func ReplaceGlobal(t *Tunnel) func() { + _globalMu.Lock() + prev := _globalT + _globalT = t + _globalMu.Unlock() + return func() { ReplaceGlobal(prev) } +} diff --git a/tunnel/statistic/manager.go b/tunnel/statistic/manager.go index d9ff978..5842fc5 100644 --- a/tunnel/statistic/manager.go +++ b/tunnel/statistic/manager.go @@ -18,7 +18,6 @@ func init() { uploadTotal: atomic.NewInt64(0), downloadTotal: atomic.NewInt64(0), } - go DefaultManager.handle() } diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 393c200..ee37761 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -50,11 +50,6 @@ func NewTCPTracker(conn net.Conn, metadata *M.Metadata, manager *Manager) net.Co return tt } -// DefaultTCPTracker returns a new net.Conn(*tcpTacker) with default manager. -func DefaultTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn { - return NewTCPTracker(conn, metadata, DefaultManager) -} - func (tt *tcpTracker) ID() string { return tt.UUID.String() } @@ -120,11 +115,6 @@ func NewUDPTracker(conn net.PacketConn, metadata *M.Metadata, manager *Manager) return ut } -// DefaultUDPTracker returns a new net.PacketConn(*udpTacker) with default manager. -func DefaultUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { - return NewUDPTracker(conn, metadata, DefaultManager) -} - func (ut *udpTracker) ID() string { return ut.UUID.String() } diff --git a/tunnel/tcp.go b/tunnel/tcp.go index 03cebab..90e6235 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -1,6 +1,7 @@ package tunnel import ( + "context" "io" "net" "sync" @@ -10,16 +11,10 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" - "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -const ( - // tcpWaitTimeout implements a TCP half-close timeout. - tcpWaitTimeout = 60 * time.Second -) - -func handleTCPConn(originConn adapter.TCPConn) { +func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) { defer originConn.Close() id := originConn.ID() @@ -31,21 +26,24 @@ func handleTCPConn(originConn adapter.TCPConn) { DstPort: id.LocalPort, } - remoteConn, err := proxy.Dial(metadata) + ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) + defer cancel() + + remoteConn, err := t.Dialer().DialContext(ctx, metadata) if err != nil { log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err) return } metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr()) - remoteConn = statistic.DefaultTCPTracker(remoteConn, metadata) + remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager) defer remoteConn.Close() log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) pipe(originConn, remoteConn) } -// pipe copies copy data to & from provided net.Conn(s) bidirectionally. +// pipe copies data to & from provided net.Conn(s) bidirectionally. func pipe(origin, remote net.Conn) { wg := sync.WaitGroup{} wg.Add(2) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 8ced53f..9f69d5d 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,36 +1,116 @@ package tunnel import ( + "context" + "sync" + "time" + + "go.uber.org/atomic" + "github.com/xjasonlyu/tun2socks/v2/core/adapter" + "github.com/xjasonlyu/tun2socks/v2/proxy" + "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -// Unbuffered TCP/UDP queues. -var ( - _tcpQueue = make(chan adapter.TCPConn) - _udpQueue = make(chan adapter.UDPConn) +const ( + // tcpConnectTimeout is the default timeout for TCP handshakes. + tcpConnectTimeout = 5 * time.Second + // tcpWaitTimeout implements a TCP half-close timeout. + tcpWaitTimeout = 60 * time.Second + // udpSessionTimeout is the default timeout for UDP sessions. + udpSessionTimeout = 60 * time.Second ) -func init() { - go process() +var _ adapter.TransportHandler = (*Tunnel)(nil) + +type Tunnel struct { + // Unbuffered TCP/UDP queues. + tcpQueue chan adapter.TCPConn + udpQueue chan adapter.UDPConn + + // UDP session timeout. + udpTimeout *atomic.Duration + + // Internal proxy.Dialer for Tunnel. + dialerMu sync.RWMutex + dialer proxy.Dialer + + // Where the Tunnel statistics are sent to. + manager *statistic.Manager + + procOnce sync.Once + procCancel context.CancelFunc +} + +func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel { + return &Tunnel{ + tcpQueue: make(chan adapter.TCPConn), + udpQueue: make(chan adapter.UDPConn), + udpTimeout: atomic.NewDuration(udpSessionTimeout), + dialer: dialer, + manager: manager, + procCancel: func() { /* nop */ }, + } } // TCPIn return fan-in TCP queue. -func TCPIn() chan<- adapter.TCPConn { - return _tcpQueue +func (t *Tunnel) TCPIn() chan<- adapter.TCPConn { + return t.tcpQueue } // UDPIn return fan-in UDP queue. -func UDPIn() chan<- adapter.UDPConn { - return _udpQueue +func (t *Tunnel) UDPIn() chan<- adapter.UDPConn { + return t.udpQueue } -func process() { +func (t *Tunnel) HandleTCP(conn adapter.TCPConn) { + t.TCPIn() <- conn +} + +func (t *Tunnel) HandleUDP(conn adapter.UDPConn) { + t.UDPIn() <- conn +} + +func (t *Tunnel) process(ctx context.Context) { for { select { - case conn := <-_tcpQueue: - go handleTCPConn(conn) - case conn := <-_udpQueue: - go handleUDPConn(conn) + case conn := <-t.tcpQueue: + go t.handleTCPConn(conn) + case conn := <-t.udpQueue: + go t.handleUDPConn(conn) + case <-ctx.Done(): + return } } } + +// ProcessAsync can be safely called multiple times, but will only be effective once. +func (t *Tunnel) ProcessAsync() { + t.procOnce.Do(func() { + ctx, cancel := context.WithCancel(context.Background()) + t.procCancel = cancel + go t.process(ctx) + }) +} + +// Close closes the Tunnel and releases its resources. +func (t *Tunnel) Close() { + t.procCancel() +} + +func (t *Tunnel) Dialer() proxy.Dialer { + t.dialerMu.RLock() + d := t.dialer + t.dialerMu.RUnlock() + return d +} + +func (t *Tunnel) SetDialer(dialer proxy.Dialer) { + t.dialerMu.Lock() + t.dialer = dialer + t.dialerMu.Unlock() +} + +func (t *Tunnel) SetUDPTimeout(timeout time.Duration) { + t.udpTimeout.Store(timeout) +} diff --git a/tunnel/udp.go b/tunnel/udp.go index a10e2d4..f92af64 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -10,19 +10,11 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" - "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -// _udpSessionTimeout is the default timeout for each UDP session. -var _udpSessionTimeout = 60 * time.Second - -func SetUDPTimeout(t time.Duration) { - _udpSessionTimeout = t -} - // TODO: Port Restricted NAT support. -func handleUDPConn(uc adapter.UDPConn) { +func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) { defer uc.Close() id := uc.ID() @@ -34,14 +26,14 @@ func handleUDPConn(uc adapter.UDPConn) { DstPort: id.LocalPort, } - pc, err := proxy.DialUDP(metadata) + pc, err := t.Dialer().DialUDP(metadata) if err != nil { log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) return } metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr()) - pc = statistic.DefaultUDPTracker(pc, metadata) + pc = statistic.NewUDPTracker(pc, metadata, t.manager) defer pc.Close() var remote net.Addr @@ -53,22 +45,22 @@ func handleUDPConn(uc adapter.UDPConn) { pc = newSymmetricNATPacketConn(pc, metadata) log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) - pipePacket(uc, pc, remote) + pipePacket(uc, pc, remote, t.udpTimeout.Load()) } -func pipePacket(origin, remote net.PacketConn, to net.Addr) { +func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) { wg := sync.WaitGroup{} wg.Add(2) - go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg) - go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg) + go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout) + go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout) wg.Wait() } -func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup) { +func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() - if err := copyPacketData(dst, src, to, _udpSessionTimeout); err != nil { + if err := copyPacketData(dst, src, to, timeout); err != nil { log.Debugf("[UDP] copy data for %s: %v", dir, err) } }