From dd0cde04b4a40fdfec2bdcb0d7e96efc9bdb13bc Mon Sep 17 00:00:00 2001 From: xjasonlyu Date: Sat, 5 Feb 2022 15:49:03 +0800 Subject: [PATCH] Refactor: optimize UDP module Symmetric NAT support for now. --- core/adapter.go | 28 ++----- core/handler.go | 6 +- core/stack/tcp.go | 20 +---- core/stack/udp.go | 190 +++------------------------------------------- engine/tunnel.go | 8 +- tunnel/addr.go | 25 ++++++ tunnel/tcp.go | 23 +++--- tunnel/tunnel.go | 51 ++++--------- tunnel/udp.go | 139 ++++++++++++--------------------- tunnel/util.go | 20 ----- 10 files changed, 124 insertions(+), 386 deletions(-) create mode 100644 tunnel/addr.go delete mode 100644 tunnel/util.go diff --git a/core/adapter.go b/core/adapter.go index 98803bc..ea184c0 100644 --- a/core/adapter.go +++ b/core/adapter.go @@ -2,33 +2,15 @@ package core import ( "net" - - "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// TCPConn implements the net.Conn interface. type TCPConn interface { net.Conn - ID() *stack.TransportEndpointID } -type UDPPacket interface { - // Data get the payload of UDP Packet. - Data() []byte - - // Drop call after packet is used, could release resources in this function. - Drop() - - // ID returns the transport endpoint id of packet. - ID() *stack.TransportEndpointID - - // LocalAddr returns the source IP/Port of packet. - LocalAddr() net.Addr - - // RemoteAddr returns the destination IP/Port of packet. - RemoteAddr() net.Addr - - // WriteBack writes the payload with source IP/Port equals addr - // - variable source IP/Port is important to STUN - // - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target. - WriteBack([]byte, net.Addr) (int, error) +// UDPConn implements net.Conn and net.PacketConn. +type UDPConn interface { + net.Conn + net.PacketConn } diff --git a/core/handler.go b/core/handler.go index 8606df1..0a9345d 100644 --- a/core/handler.go +++ b/core/handler.go @@ -1,6 +1,8 @@ package core +// Handler is a TCP/UDP connection handler that implements +// HandleTCPConn and HandleUDPConn methods. type Handler interface { - Add(TCPConn) - AddPacket(UDPPacket) + HandleTCPConn(TCPConn) + HandleUDPConn(UDPConn) } diff --git a/core/stack/tcp.go b/core/stack/tcp.go index 309ad01..ad468f3 100644 --- a/core/stack/tcp.go +++ b/core/stack/tcp.go @@ -2,12 +2,10 @@ package stack import ( "fmt" - "net" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -36,10 +34,9 @@ func withTCPHandler() Option { return func(s *Stack) error { tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - id := r.ID() ep, err := r.CreateEndpoint(&wq) if err != nil { - // prevent potential half-open TCP connection leak. + // RST: prevent potential half-open TCP connection leak. r.Complete(true) return } @@ -47,11 +44,7 @@ func withTCPHandler() Option { setKeepalive(ep) - conn := &tcpConn{ - Conn: gonet.NewTCPConn(&wq, ep), - id: &id, - } - s.handler.Add(conn) + s.handler.HandleTCPConn(gonet.NewTCPConn(&wq, ep)) }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) return nil @@ -72,12 +65,3 @@ func setKeepalive(ep tcpip.Endpoint) error { } return nil } - -type tcpConn struct { - net.Conn - id *stack.TransportEndpointID -} - -func (c *tcpConn) ID() *stack.TransportEndpointID { - return c.id -} diff --git a/core/stack/udp.go b/core/stack/udp.go index 63b0b90..4a130e7 100644 --- a/core/stack/udp.go +++ b/core/stack/udp.go @@ -1,192 +1,24 @@ package stack import ( - "fmt" - "net" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -const ( - // udpNoChecksum disables UDP checksum if set to true. - udpNoChecksum = true + "gvisor.dev/gvisor/pkg/waiter" ) func withUDPHandler() Option { return func(s *Stack) error { - udpHandlePacket := func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - // Ref: gVisor pkg/tcpip/transport/udp/endpoint.go HandlePacket. - udpHdr := header.UDP(pkt.TransportHeader().View()) - if int(udpHdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize { - // Malformed packet. - s.Stats().UDP.MalformedPacketsReceived.Increment() - return true + udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + // TODO: handler errors in the future. + return } - if !verifyChecksum(udpHdr, pkt) { - // Checksum error. - s.Stats().UDP.ChecksumErrors.Increment() - return true - } - - s.Stats().UDP.PacketsReceived.Increment() - - packet := &udpPacket{ - s: s, - id: &id, - data: pkt.Data().ExtractVV(), - nicID: pkt.NICID, - netHdr: pkt.Network(), - netProto: pkt.NetworkProtocolNumber, - } - - s.handler.AddPacket(packet) - return true - } - s.SetTransportProtocolHandler(udp.ProtocolNumber, udpHandlePacket) + s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep)) + }) + s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) return nil } } - -type udpPacket struct { - s *Stack - id *stack.TransportEndpointID - data buffer.VectorisedView - nicID tcpip.NICID - netHdr header.Network - netProto tcpip.NetworkProtocolNumber -} - -func (p *udpPacket) Data() []byte { - return p.data.ToView() -} - -func (p *udpPacket) Drop() {} - -func (p *udpPacket) ID() *stack.TransportEndpointID { - return p.id -} - -func (p *udpPacket) LocalAddr() net.Addr { - return &net.UDPAddr{IP: net.IP(p.id.LocalAddress), Port: int(p.id.LocalPort)} -} - -func (p *udpPacket) RemoteAddr() net.Addr { - return &net.UDPAddr{IP: net.IP(p.id.RemoteAddress), Port: int(p.id.RemotePort)} -} - -func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) { - v := buffer.View(b) - if len(v) > header.UDPMaximumPacketSize { - // Payload can't possibly fit in a packet. - return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{}) - } - - var ( - localAddress tcpip.Address - localPort uint16 - ) - - if udpAddr, ok := addr.(*net.UDPAddr); !ok { - localAddress = p.netHdr.DestinationAddress() - localPort = p.id.LocalPort - } else if ipv4 := udpAddr.IP.To4(); ipv4 != nil { - localAddress = tcpip.Address(ipv4) - localPort = uint16(udpAddr.Port) - } else { - localAddress = tcpip.Address(udpAddr.IP) - localPort = uint16(udpAddr.Port) - } - - route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */) - if err != nil { - return 0, fmt.Errorf("%#v find route: %s", p.id, err) - } - defer route.Release() - - data := v.ToVectorisedView() - if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil { - return 0, fmt.Errorf("%v", err) - } - return data.Size(), nil -} - -// sendUDP sends a UDP segment via the provided network endpoint and under the -// provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, noChecksum bool) tcpip.Error { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: data, - }) - defer pkt.DecRef() - - // Initialize the UDP header. - udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - pkt.TransportProtocolNumber = udp.ProtocolNumber - - length := uint16(pkt.Size()) - udpHdr.Encode(&header.UDPFields{ - SrcPort: localPort, - DstPort: remotePort, - Length: length, - }) - - // Set the checksum field unless TX checksum offload is enabled. - // On IPv4, UDP checksum is optional, and a zero value indicates the - // transmitter skipped the checksum generation (RFC768). - // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.RequiresTXTransportChecksum() && - (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) - for _, v := range data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - } - - ttl := r.DefaultTTL() - - if err := r.WritePacket(stack.NetworkHeaderParams{ - Protocol: udp.ProtocolNumber, - TTL: ttl, - TOS: 0, /* default */ - }, pkt); err != nil { - r.Stats().UDP.PacketSendErrors.Increment() - return err - } - - // Track count of packets sent. - r.Stats().UDP.PacketsSent.Increment() - return nil -} - -// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go verifyChecksum. -// verifyChecksum verifies the checksum unless RX checksum offload is enabled. -func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { - if pkt.RXTransportChecksumValidated { - return true - } - - // On IPv4, UDP checksum is optional, and a zero value means the transmitter - // omitted the checksum generation, as per RFC 768: - // - // An all zero transmitted checksum value means that the transmitter - // generated no checksum (for debugging or for higher level protocols that - // don't care). - // - // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1: - // - // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP - // checksum is not optional. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 { - return true - } - - netHdr := pkt.Network() - payloadChecksum := pkt.Data().AsRange().Checksum() - return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum) -} diff --git a/engine/tunnel.go b/engine/tunnel.go index 336f5aa..627d6c3 100644 --- a/engine/tunnel.go +++ b/engine/tunnel.go @@ -9,10 +9,10 @@ var _ core.Handler = (*fakeTunnel)(nil) type fakeTunnel struct{} -func (*fakeTunnel) Add(conn core.TCPConn) { - tunnel.Add(conn) +func (*fakeTunnel) HandleTCPConn(conn core.TCPConn) { + tunnel.TCPIn() <- conn } -func (*fakeTunnel) AddPacket(packet core.UDPPacket) { - tunnel.AddPacket(packet) +func (*fakeTunnel) HandleUDPConn(conn core.UDPConn) { + tunnel.UDPIn() <- conn } diff --git a/tunnel/addr.go b/tunnel/addr.go new file mode 100644 index 0000000..bf06b46 --- /dev/null +++ b/tunnel/addr.go @@ -0,0 +1,25 @@ +package tunnel + +import ( + "net" + "strconv" +) + +// parseAddr parses net.Addr to IP and port. +func parseAddr(addr net.Addr) (net.IP, uint16) { + switch v := addr.(type) { + case *net.TCPAddr: + return v.IP, uint16(v.Port) + case *net.UDPAddr: + return v.IP, uint16(v.Port) + default: + return parseAddrString(addr.String()) + } +} + +// parseAddrString parses address string to IP and port. +func parseAddrString(addr string) (net.IP, uint16) { + host, port, _ := net.SplitHostPort(addr) + portInt, _ := strconv.ParseUint(port, 10, 16) + return net.ParseIP(host), uint16(portInt) +} diff --git a/tunnel/tcp.go b/tunnel/tcp.go index b64af7e..1c3c38e 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -22,16 +22,19 @@ func newTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn { return statistic.NewTCPTracker(conn, metadata, statistic.DefaultManager) } -func handleTCP(localConn core.TCPConn) { +func handleTCPConn(localConn core.TCPConn) { defer localConn.Close() - id := localConn.ID() + var ( + srcIP, srcPort = parseAddr(localConn.RemoteAddr()) + dstIP, dstPort = parseAddr(localConn.LocalAddr()) + ) metadata := &M.Metadata{ Net: M.TCP, - SrcIP: net.IP(id.RemoteAddress), - SrcPort: id.RemotePort, - DstIP: net.IP(id.LocalAddress), - DstPort: id.LocalPort, + SrcIP: srcIP, + SrcPort: srcPort, + DstIP: dstIP, + DstPort: dstPort, } targetConn, err := proxy.Dial(metadata) @@ -39,13 +42,7 @@ func handleTCP(localConn core.TCPConn) { log.Warnf("[TCP] dial %s error: %v", metadata.DestinationAddress(), err) return } - - if dialerAddr, ok := targetConn.LocalAddr().(*net.TCPAddr); ok { - metadata.MidIP = dialerAddr.IP - metadata.MidPort = uint16(dialerAddr.Port) - } else { /* fallback */ - metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr().String()) - } + metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr()) targetConn = newTCPTracker(targetConn, metadata) defer targetConn.Close() diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 20894e0..a8e2346 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,55 +1,36 @@ package tunnel import ( - "runtime" - "github.com/xjasonlyu/tun2socks/v2/core" - "github.com/xjasonlyu/tun2socks/v2/log" -) - -const ( - // maxUDPQueueSize is the max number of UDP packets - // could be buffered. if queue is full, upcoming packets - // would be dropped util queue is ready again. - maxUDPQueueSize = 1 << 9 ) +// Unbuffered TCP/UDP queues. var ( - _tcpQueue = make(chan core.TCPConn) /* unbuffered */ - _udpQueue = make(chan core.UDPPacket, maxUDPQueueSize) - _numUDPWorkers = max(runtime.GOMAXPROCS(0), 4 /* at least 4 workers */) + _tcpQueue = make(chan core.TCPConn) + _udpQueue = make(chan core.UDPConn) ) func init() { go process() } -// Add adds tcpConn to tcpQueue. -func Add(conn core.TCPConn) { - _tcpQueue <- conn +// TCPIn return fan-in TCP queue. +func TCPIn() chan<- core.TCPConn { + return _tcpQueue } -// AddPacket adds udpPacket to udpQueue. -func AddPacket(packet core.UDPPacket) { - select { - case _udpQueue <- packet: - default: - log.Warnf("queue is currently full, packet will be dropped") - packet.Drop() - } +// UDPIn return fan-in UDP queue. +func UDPIn() chan<- core.UDPConn { + return _udpQueue } func process() { - for i := 0; i < _numUDPWorkers; i++ { - queue := _udpQueue - go func() { - for packet := range queue { - handleUDP(packet) - } - }() - } - - for conn := range _tcpQueue { - go handleTCP(conn) + for { + select { + case conn := <-_tcpQueue: + go handleTCPConn(conn) + case conn := <-_udpQueue: + go handleUDPConn(conn) + } } } diff --git a/tunnel/udp.go b/tunnel/udp.go index f088cca..9017e46 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -7,7 +7,6 @@ import ( "time" "github.com/xjasonlyu/tun2socks/v2/common/pool" - "github.com/xjasonlyu/tun2socks/v2/component/nat" "github.com/xjasonlyu/tun2socks/v2/core" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" @@ -15,15 +14,8 @@ import ( "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -var ( - // _natTable uses source udp packet information - // as key to store destination udp packetConn. - _natTable = nat.NewTable() - - // _udpSessionTimeout is the default timeout for - // each UDP session. - _udpSessionTimeout = 60 * time.Second -) +// _udpSessionTimeout is the default timeout for each UDP session. +var _udpSessionTimeout = 60 * time.Second func SetUDPTimeout(v int) { _udpSessionTimeout = time.Duration(v) * time.Second @@ -33,98 +25,58 @@ func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager) } -func handleUDP(packet core.UDPPacket) { - id := packet.ID() +func handleUDPConn(uc core.UDPConn) { + defer uc.Close() + + var ( + srcIP, srcPort = parseAddr(uc.RemoteAddr()) + dstIP, dstPort = parseAddr(uc.LocalAddr()) + ) metadata := &M.Metadata{ Net: M.UDP, - SrcIP: net.IP(id.RemoteAddress), - SrcPort: id.RemotePort, - DstIP: net.IP(id.LocalAddress), - DstPort: id.LocalPort, + SrcIP: srcIP, + SrcPort: srcPort, + DstIP: dstIP, + DstPort: dstPort, } - generateNATKey := func(m *M.Metadata) string { - return m.SourceAddress() /* as Full Cone NAT Key */ - } - key := generateNATKey(metadata) - - handle := func(drop bool) bool { - pc := _natTable.Get(key) - if pc != nil { - handleUDPToRemote(packet, pc, metadata /* as net.Addr */, drop) - return true - } - return false - } - - if handle(true /* drop */) { + pc, err := proxy.DialUDP(metadata) + if err != nil { + log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err) return } + metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr()) - lockKey := key + "-lock" - cond, loaded := _natTable.GetOrCreateLock(lockKey) - go func() { - if loaded { - cond.L.Lock() - cond.Wait() - handle(true) /* drop after sending data to remote */ - cond.L.Unlock() - return - } + pc = newUDPTracker(pc, metadata) + defer pc.Close() - defer func() { - _natTable.Delete(lockKey) - cond.Broadcast() - }() - - pc, err := proxy.DialUDP(metadata) - if err != nil { - log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err) - return - } - - if dialerAddr, ok := pc.LocalAddr().(*net.UDPAddr); ok { - metadata.MidIP = dialerAddr.IP - metadata.MidPort = uint16(dialerAddr.Port) - } else { /* fallback */ - metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr().String()) - } - - pc = newUDPTracker(pc, metadata) - - go func() { - defer pc.Close() - defer packet.Drop() - defer _natTable.Delete(key) - - handleUDPToLocal(packet, pc) - }() - - _natTable.Set(key, pc) - handle(false /* drop */) - }() + go handleUDPToRemote(uc, pc, metadata) + handleUDPToLocal(uc, pc, metadata) } -func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr, drop bool) { - defer func() { - if drop { - packet.Drop() - } - }() - - if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil { - log.Warnf("[UDP] write to %s error: %v", remote, err) - } - pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */ - - log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote) -} - -func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) { +func handleUDPToRemote(uc core.UDPConn, pc net.PacketConn, remote net.Addr) { buf := pool.Get(pool.MaxSegmentSize) defer pool.Put(buf) - for /* just loop */ { + for { + n, err := uc.Read(buf) + if err != nil { + return + } + + if _, err := pc.WriteTo(buf[:n], remote); err != nil { + log.Warnf("[UDP] write to %s error: %v", remote, err) + } + + log.Infof("[UDP] %s --> %s", uc.RemoteAddr(), remote) + } +} + +func handleUDPToLocal(uc core.UDPConn, pc net.PacketConn, remote net.Addr) { + buf := pool.Get(pool.MaxSegmentSize) + defer pool.Put(buf) + + for { pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) n, from, err := pc.ReadFrom(buf) if err != nil { @@ -134,11 +86,14 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) { return } - if _, err := packet.WriteBack(buf[:n], from); err != nil { - log.Warnf("[UDP] write back from %s error: %v", from, err) + if from.Network() != remote.Network() || from.String() != remote.String() { + log.Warnf("[UDP] drop unknown packet from %s", from) return } - log.Infof("[UDP] %s <-- %s", packet.RemoteAddr(), from) + if _, err := uc.Write(buf[:n]); err != nil { + log.Warnf("[UDP] write back from %s error: %v", from, err) + return + } } } diff --git a/tunnel/util.go b/tunnel/util.go deleted file mode 100644 index c7b7f7c..0000000 --- a/tunnel/util.go +++ /dev/null @@ -1,20 +0,0 @@ -package tunnel - -import ( - "net" - "strconv" -) - -func max(a, b int) int { - if a > b { - return a - } - return b -} - -// parseAddr parses address to IP and port. -func parseAddr(addr string) (net.IP, uint16) { - host, portStr, _ := net.SplitHostPort(addr) - portInt, _ := strconv.ParseUint(portStr, 10, 16) - return net.ParseIP(host), uint16(portInt) -}