diff --git a/pkg/core/gvisorudpforwarder.go b/pkg/core/gvisorudpforwarder.go index c795413e..4e956dd9 100644 --- a/pkg/core/gvisorudpforwarder.go +++ b/pkg/core/gvisorudpforwarder.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "time" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -38,9 +39,9 @@ func UDPForwarder(s *stack.Stack, ctx context.Context) func(id stack.TransportEn } // dial dst - remote, err := net.DialUDP("udp", nil, dst) - if err != nil { - log.Errorf("[TUN-UDP] Failed to connect dst: %s: %v", dst.String(), err) + remote, err1 := net.DialUDP("udp", nil, dst) + if err1 != nil { + log.Errorf("[TUN-UDP] Failed to connect dst: %s: %v", dst.String(), err1) return } @@ -52,20 +53,64 @@ func UDPForwarder(s *stack.Stack, ctx context.Context) func(id stack.TransportEn go func() { buf := config.LPool.Get().([]byte)[:] defer config.LPool.Put(buf[:]) - written, err2 := io.CopyBuffer(remote, conn, buf) + + var written int + var err error + for { + err = conn.SetReadDeadline(time.Now().Add(time.Second * 120)) + if err != nil { + break + } + var read int + read, _, err = conn.ReadFrom(buf[:]) + if err != nil { + break + } + written += read + err = remote.SetWriteDeadline(time.Now().Add(time.Second * 120)) + if err != nil { + break + } + _, err = remote.Write(buf[:read]) + if err != nil { + break + } + } log.Debugf("[TUN-UDP] Write length %d data from src: %s -> dst: %s", written, src.String(), dst.String()) - errChan <- err2 + errChan <- err }() go func() { buf := config.LPool.Get().([]byte)[:] defer config.LPool.Put(buf[:]) - written, err2 := io.CopyBuffer(conn, remote, buf) + + var err error + var written int + for { + err = remote.SetReadDeadline(time.Now().Add(time.Second * 120)) + if err != nil { + break + } + var n int + n, _, err = remote.ReadFromUDP(buf[:]) + if err != nil { + break + } + written += n + err = conn.SetWriteDeadline(time.Now().Add(time.Second * 120)) + if err != nil { + break + } + _, err = conn.Write(buf[:n]) + if err != nil { + break + } + } log.Debugf("[TUN-UDP] Read length %d data from dst: %s -> src: %s", written, dst.String(), src.String()) - errChan <- err2 + errChan <- err }() - err = <-errChan - if err != nil && !errors.Is(err, io.EOF) { - log.Debugf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err) + err1 = <-errChan + if err1 != nil && !errors.Is(err1, io.EOF) { + log.Debugf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err1) } }() }).HandlePacket