From 4df63d1642db8a9eb5d8ecaeaceac2cbfb9b1980 Mon Sep 17 00:00:00 2001 From: naison <895703375@qq.com> Date: Thu, 14 Aug 2025 19:13:17 +0800 Subject: [PATCH] refactor: no drop packet (#703) --- pkg/core/bufferedtcp.go | 22 +++++++++++--- pkg/core/gvisortcphandler.go | 2 +- pkg/core/gvisortunendpoint.go | 7 ++--- pkg/core/tunhandler.go | 43 ++++++++++++++++++++------- pkg/core/tunhandlerclient.go | 55 ++++++++++++++++++++--------------- 5 files changed, 84 insertions(+), 45 deletions(-) diff --git a/pkg/core/bufferedtcp.go b/pkg/core/bufferedtcp.go index cbcc21fd..1e904e3a 100644 --- a/pkg/core/bufferedtcp.go +++ b/pkg/core/bufferedtcp.go @@ -15,12 +15,12 @@ type bufferedTCP struct { closed bool } -func NewBufferedTCP(conn net.Conn) net.Conn { +func NewBufferedTCP(ctx context.Context, conn net.Conn) net.Conn { c := &bufferedTCP{ Conn: conn, Chan: make(chan *DatagramPacket, MaxSize), } - go c.Run() + go c.Run(ctx) return c } @@ -38,8 +38,17 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) { return n, nil } -func (c *bufferedTCP) Run() { - for buf := range c.Chan { +func (c *bufferedTCP) Run(ctx context.Context) { + for ctx.Err() == nil { + var buf *DatagramPacket + select { + case buf = <-c.Chan: + if buf == nil { + return + } + case <-ctx.Done(): + return + } _, err := c.Conn.Write(buf.Data[:buf.DataLength]) config.LPool.Put(buf.Data[:]) if err != nil { @@ -50,3 +59,8 @@ func (c *bufferedTCP) Run() { } } } + +func (c *bufferedTCP) Close() error { + c.closed = true + return c.Conn.Close() +} diff --git a/pkg/core/gvisortcphandler.go b/pkg/core/gvisortcphandler.go index cd1e1a27..3fadee2f 100644 --- a/pkg/core/gvisortcphandler.go +++ b/pkg/core/gvisortcphandler.go @@ -45,7 +45,7 @@ func (h *gvisorTCPHandler) handle(ctx context.Context, tcpConn net.Conn) { errChan := make(chan error, 2) go func() { defer util.HandleCrash() - h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(tcpConn), endpoint) + h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(ctx, tcpConn), endpoint) util.SafeClose(errChan) }() go func() { diff --git a/pkg/core/gvisortunendpoint.go b/pkg/core/gvisortunendpoint.go index b61f4770..2774b3ba 100755 --- a/pkg/core/gvisortunendpoint.go +++ b/pkg/core/gvisortunendpoint.go @@ -114,15 +114,12 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c pkt.DecRef() plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read) } else { - util.SafeWrite(TCPPacketChan, &Packet{ + TCPPacketChan <- &Packet{ data: buf[:], length: read, src: src, dst: dst, - }, func(v *Packet) { - config.LPool.Put(buf[:]) - plog.G(ctx).Debugf("[TCP-TUN] Drop packet. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read) - }) + } } } } diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 3a051082..6db66b53 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -74,7 +74,7 @@ type Device struct { func (d *Device) readFromTun(ctx context.Context) { defer util.HandleCrash() - for { + for ctx.Err() == nil { buf := config.LPool.Get().([]byte)[:] n, err := d.tun.Read(buf[:]) if err != nil { @@ -92,16 +92,22 @@ func (d *Device) readFromTun(ctx context.Context) { } plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) - util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) { - config.LPool.Put(v.data[:]) - plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length) - }) + d.tunInbound <- NewPacket(buf[:], n, src, dst) } } func (d *Device) writeToTun(ctx context.Context) { defer util.HandleCrash() - for packet := range d.tunOutbound { + for ctx.Err() == nil { + var packet *Packet + select { + case packet = <-d.tunOutbound: + if packet == nil { + return + } + case <-ctx.Done(): + return + } _, err := d.tun.Write(packet.data[1:packet.length]) config.LPool.Put(packet.data[:]) if err != nil { @@ -114,9 +120,6 @@ func (d *Device) writeToTun(ctx context.Context) { func (d *Device) Close() { d.tun.Close() - util.SafeClose(d.tunInbound) - util.SafeClose(d.tunOutbound) - util.SafeClose(TCPPacketChan) } func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) { @@ -183,14 +186,32 @@ func (p *Peer) sendErr(err error) { func (p *Peer) routeTCPToTun(ctx context.Context) { defer util.HandleCrash() - for packet := range TCPPacketChan { + for ctx.Err() == nil { + var packet *Packet + select { + case packet = <-TCPPacketChan: + if packet == nil { + return + } + case <-ctx.Done(): + return + } p.tunOutbound <- packet } } func (p *Peer) routeTun(ctx context.Context) { defer util.HandleCrash() - for packet := range p.tunInbound { + for ctx.Err() == nil { + var packet *Packet + select { + case packet = <-p.tunInbound: + if packet == nil { + return + } + case <-ctx.Done(): + return + } if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr()) copy(packet.data[1:packet.length+1], packet.data[:packet.length]) diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 6d22ab41..a5efb598 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -77,7 +77,7 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t defer util.HandleCrash() var gvisorInbound = make(chan *Packet, MaxSize) go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx) - for { + for ctx.Err() == nil { buf := config.LPool.Get().([]byte)[:] err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime)) if err != nil { @@ -99,22 +99,25 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t continue } if buf[0] == 1 { - util.SafeWrite(gvisorInbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) { - config.LPool.Put(v.data[:]) - plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length) - }) + gvisorInbound <- NewPacket(buf[:], n, nil, nil) } else { - util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) { - config.LPool.Put(v.data[:]) - plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length) - }) + tunOutbound <- NewPacket(buf[:], n, nil, nil) } } } func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) { defer util.HandleCrash() - for packet := range inbound { + for ctx.Err() == nil { + var packet *Packet + select { + case packet = <-inbound: + if packet == nil { + return + } + case <-ctx.Done(): + return + } err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime)) if err != nil { plog.G(ctx).Errorf("Failed to set write deadline: %v", err) @@ -135,7 +138,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) { defer util.HandleCrash() var gvisorInbound = make(chan *Packet, MaxSize) go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx) - for { + for ctx.Err() == nil { buf := config.LPool.Get().([]byte)[:] n, err := d.tun.Read(buf[1:]) if err != nil { @@ -158,21 +161,27 @@ func (d *ClientDevice) readFromTun(ctx context.Context) { } plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) packet := NewPacket(buf[:], n+1, src, dst) - f := func(v *Packet) { - config.LPool.Put(v.data[:]) - plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length) - } if packet.src.Equal(packet.dst) { - util.SafeWrite(gvisorInbound, packet, f) + gvisorInbound <- packet } else { - util.SafeWrite(d.tunInbound, packet, f) + d.tunInbound <- packet } } } func (d *ClientDevice) writeToTun(ctx context.Context) { defer util.HandleCrash() - for packet := range d.tunOutbound { + for ctx.Err() == nil { + var packet *Packet + select { + case packet = <-d.tunOutbound: + if packet == nil { + return + } + case <-ctx.Done(): + return + } + _, err := d.tun.Write(packet.data[1:packet.length]) config.LPool.Put(packet.data[:]) if err != nil { @@ -185,8 +194,6 @@ func (d *ClientDevice) writeToTun(ctx context.Context) { func (d *ClientDevice) Close() { d.tun.Close() - util.SafeClose(d.tunInbound) - util.SafeClose(d.tunOutbound) } func (d *ClientDevice) heartbeats(ctx context.Context) { @@ -214,10 +221,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) { data := config.LPool.Get().([]byte)[:] length := copy(data[1:], bytes) data[0] = 1 - util.SafeWrite(d.tunInbound, &Packet{ + d.tunInbound <- &Packet{ data: data[:], length: length + 1, - }) + } } } if srcIPv6 != nil { @@ -228,10 +235,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) { data := config.LPool.Get().([]byte)[:] length := copy(data[1:], bytes6) data[0] = 1 - util.SafeWrite(d.tunInbound, &Packet{ + d.tunInbound <- &Packet{ data: data[:], length: length + 1, - }) + } } }