hotfix: add idle timeout 120s for gvisor udp forwarder connection (#374)

This commit is contained in:
naison
2024-11-22 22:03:56 +08:00
committed by GitHub
parent 98c22ba9b7
commit 17a13a2672

View File

@@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
@@ -38,9 +39,9 @@ func UDPForwarder(s *stack.Stack, ctx context.Context) func(id stack.TransportEn
}
// dial dst
remote, err := net.DialUDP("udp", nil, dst)
if err != nil {
log.Errorf("[TUN-UDP] Failed to connect dst: %s: %v", dst.String(), err)
remote, err1 := net.DialUDP("udp", nil, dst)
if err1 != nil {
log.Errorf("[TUN-UDP] Failed to connect dst: %s: %v", dst.String(), err1)
return
}
@@ -52,20 +53,64 @@ func UDPForwarder(s *stack.Stack, ctx context.Context) func(id stack.TransportEn
go func() {
buf := config.LPool.Get().([]byte)[:]
defer config.LPool.Put(buf[:])
written, err2 := io.CopyBuffer(remote, conn, buf)
var written int
var err error
for {
err = conn.SetReadDeadline(time.Now().Add(time.Second * 120))
if err != nil {
break
}
var read int
read, _, err = conn.ReadFrom(buf[:])
if err != nil {
break
}
written += read
err = remote.SetWriteDeadline(time.Now().Add(time.Second * 120))
if err != nil {
break
}
_, err = remote.Write(buf[:read])
if err != nil {
break
}
}
log.Debugf("[TUN-UDP] Write length %d data from src: %s -> dst: %s", written, src.String(), dst.String())
errChan <- err2
errChan <- err
}()
go func() {
buf := config.LPool.Get().([]byte)[:]
defer config.LPool.Put(buf[:])
written, err2 := io.CopyBuffer(conn, remote, buf)
var err error
var written int
for {
err = remote.SetReadDeadline(time.Now().Add(time.Second * 120))
if err != nil {
break
}
var n int
n, _, err = remote.ReadFromUDP(buf[:])
if err != nil {
break
}
written += n
err = conn.SetWriteDeadline(time.Now().Add(time.Second * 120))
if err != nil {
break
}
_, err = conn.Write(buf[:n])
if err != nil {
break
}
}
log.Debugf("[TUN-UDP] Read length %d data from dst: %s -> src: %s", written, dst.String(), src.String())
errChan <- err2
errChan <- err
}()
err = <-errChan
if err != nil && !errors.Is(err, io.EOF) {
log.Debugf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
err1 = <-errChan
if err1 != nil && !errors.Is(err1, io.EOF) {
log.Debugf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err1)
}
}()
}).HandlePacket