package tunnel import ( "net" "sync" "time" "github.com/xjasonlyu/tun2socks/v2/common/pool" "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) // _udpSessionTimeout is the default timeout for each UDP session. var _udpSessionTimeout = 60 * time.Second func SetUDPTimeout(t time.Duration) { _udpSessionTimeout = t } func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager) } // TODO: Port Restricted NAT support. func handleUDPConn(uc adapter.UDPConn) { defer uc.Close() id := uc.ID() metadata := &M.Metadata{ Network: M.UDP, SrcIP: net.IP(id.RemoteAddress), SrcPort: id.RemotePort, DstIP: net.IP(id.LocalAddress), DstPort: id.LocalPort, } pc, err := proxy.DialUDP(metadata) if err != nil { log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) return } metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr()) pc = newUDPTracker(pc, metadata) defer pc.Close() var remote net.Addr if udpAddr := metadata.UDPAddr(); udpAddr != nil { remote = udpAddr } else { remote = metadata.Addr() } log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) relayPacket(uc, pc, remote) } func relayPacket(left net.PacketConn, right net.PacketConn, to net.Addr) { wg := sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() if err := copyPacketBuffer(right, left, to, _udpSessionTimeout); err != nil { log.Warnf("[UDP] %v", err) } }() go func() { defer wg.Done() if err := copyPacketBuffer(left, right, nil, _udpSessionTimeout); err != nil { log.Warnf("[UDP] %v", err) } }() wg.Wait() } func copyPacketBuffer(dst net.PacketConn, src net.PacketConn, to net.Addr, timeout time.Duration) error { buf := pool.Get(pool.MaxSegmentSize) defer pool.Put(buf) for { src.SetReadDeadline(time.Now().Add(timeout)) n, _, err := src.ReadFrom(buf) if ne, ok := err.(net.Error); ok && ne.Timeout() { return nil /* ignore I/O timeout */ } else if err != nil { return err } if _, err = dst.WriteTo(buf[:n], to); err != nil { return err } dst.SetReadDeadline(time.Now().Add(timeout)) } }