diff --git a/tunnel/udp.go b/tunnel/udp.go index 7ea28cd..80d74f7 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -1,9 +1,8 @@ package tunnel import ( - "errors" "net" - "os" + "sync" "time" "github.com/xjasonlyu/tun2socks/v2/common/pool" @@ -55,53 +54,48 @@ func handleUDPConn(uc adapter.UDPConn) { remote = metadata.Addr() } - go handleUDPToRemote(uc, pc, remote) - handleUDPToLocal(uc, pc, remote) + log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) + relayPacket(uc, pc, remote) } -func handleUDPToRemote(uc adapter.UDPConn, pc net.PacketConn, remote net.Addr) { +func relayPacket(left net.PacketConn, right net.PacketConn, to net.Addr) { + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + if err := copyPacketBuffer(right, left, to, _udpSessionTimeout); err != nil { + log.Warnf("[UDP] copy packet buffer: %v", err) + } + }() + + go func() { + defer wg.Done() + if err := copyPacketBuffer(left, right, nil, _udpSessionTimeout); err != nil { + log.Warnf("[UDP] copy packet buffer: %v", err) + } + }() + + wg.Wait() +} + +func copyPacketBuffer(dst net.PacketConn, src net.PacketConn, to net.Addr, timeout time.Duration) error { buf := pool.Get(pool.MaxSegmentSize) defer pool.Put(buf) for { - n, err := uc.Read(buf) + src.SetReadDeadline(time.Now().Add(timeout)) + n, _, err := src.ReadFrom(buf) if err != nil { - return - } - - if _, err := pc.WriteTo(buf[:n], remote); err != nil { - log.Warnf("[UDP] write to %s: %v", remote, err) - } - pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */ - - log.Infof("[UDP] %s --> %s", uc.RemoteAddr(), remote) - } -} - -func handleUDPToLocal(uc adapter.UDPConn, pc net.PacketConn, remote net.Addr) { - buf := pool.Get(pool.MaxSegmentSize) - defer pool.Put(buf) - - for { - pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) - n, from, err := pc.ReadFrom(buf) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) /* ignore I/O timeout */ { - log.Warnf("[UDP] read from %s: %v", pc.LocalAddr(), err) + if ne, ok := err.(net.Error); ok && ne.Timeout() { + return nil /* ignore I/O timeout */ } - return + return err } - if from == nil || from.Network() != remote.Network() || from.String() != remote.String() { - log.Warnf("[UDP] drop unknown packet from %v", from) - return + if _, err = dst.WriteTo(buf[:n], to); err != nil { + return err } - - if _, err := uc.Write(buf[:n]); err != nil { - log.Warnf("[UDP] write back from %s: %v", from, err) - return - } - - log.Infof("[UDP] %s <-- %s", uc.RemoteAddr(), from) + dst.SetReadDeadline(time.Now().Add(timeout)) } }