refactor: no drop packet (#703)

This commit is contained in:
naison
2025-08-14 19:13:17 +08:00
committed by GitHub
parent 4ddba64737
commit 4df63d1642
5 changed files with 84 additions and 45 deletions

View File

@@ -15,12 +15,12 @@ type bufferedTCP struct {
closed bool closed bool
} }
func NewBufferedTCP(conn net.Conn) net.Conn { func NewBufferedTCP(ctx context.Context, conn net.Conn) net.Conn {
c := &bufferedTCP{ c := &bufferedTCP{
Conn: conn, Conn: conn,
Chan: make(chan *DatagramPacket, MaxSize), Chan: make(chan *DatagramPacket, MaxSize),
} }
go c.Run() go c.Run(ctx)
return c return c
} }
@@ -38,8 +38,17 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) {
return n, nil return n, nil
} }
func (c *bufferedTCP) Run() { func (c *bufferedTCP) Run(ctx context.Context) {
for buf := range c.Chan { for ctx.Err() == nil {
var buf *DatagramPacket
select {
case buf = <-c.Chan:
if buf == nil {
return
}
case <-ctx.Done():
return
}
_, err := c.Conn.Write(buf.Data[:buf.DataLength]) _, err := c.Conn.Write(buf.Data[:buf.DataLength])
config.LPool.Put(buf.Data[:]) config.LPool.Put(buf.Data[:])
if err != nil { if err != nil {
@@ -50,3 +59,8 @@ func (c *bufferedTCP) Run() {
} }
} }
} }
func (c *bufferedTCP) Close() error {
c.closed = true
return c.Conn.Close()
}

View File

@@ -45,7 +45,7 @@ func (h *gvisorTCPHandler) handle(ctx context.Context, tcpConn net.Conn) {
errChan := make(chan error, 2) errChan := make(chan error, 2)
go func() { go func() {
defer util.HandleCrash() defer util.HandleCrash()
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(tcpConn), endpoint) h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(ctx, tcpConn), endpoint)
util.SafeClose(errChan) util.SafeClose(errChan)
}() }()
go func() { go func() {

View File

@@ -114,15 +114,12 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c
pkt.DecRef() pkt.DecRef()
plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read) plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
} else { } else {
util.SafeWrite(TCPPacketChan, &Packet{ TCPPacketChan <- &Packet{
data: buf[:], data: buf[:],
length: read, length: read,
src: src, src: src,
dst: dst, dst: dst,
}, func(v *Packet) { }
config.LPool.Put(buf[:])
plog.G(ctx).Debugf("[TCP-TUN] Drop packet. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
})
} }
} }
} }

View File

@@ -74,7 +74,7 @@ type Device struct {
func (d *Device) readFromTun(ctx context.Context) { func (d *Device) readFromTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
for { for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
n, err := d.tun.Read(buf[:]) n, err := d.tun.Read(buf[:])
if err != nil { if err != nil {
@@ -92,16 +92,22 @@ func (d *Device) readFromTun(ctx context.Context) {
} }
plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) { d.tunInbound <- NewPacket(buf[:], n, src, dst)
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
})
} }
} }
func (d *Device) writeToTun(ctx context.Context) { func (d *Device) writeToTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
for packet := range d.tunOutbound { for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-d.tunOutbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
_, err := d.tun.Write(packet.data[1:packet.length]) _, err := d.tun.Write(packet.data[1:packet.length])
config.LPool.Put(packet.data[:]) config.LPool.Put(packet.data[:])
if err != nil { if err != nil {
@@ -114,9 +120,6 @@ func (d *Device) writeToTun(ctx context.Context) {
func (d *Device) Close() { func (d *Device) Close() {
d.tun.Close() d.tun.Close()
util.SafeClose(d.tunInbound)
util.SafeClose(d.tunOutbound)
util.SafeClose(TCPPacketChan)
} }
func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) { func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) {
@@ -183,14 +186,32 @@ func (p *Peer) sendErr(err error) {
func (p *Peer) routeTCPToTun(ctx context.Context) { func (p *Peer) routeTCPToTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
for packet := range TCPPacketChan { for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-TCPPacketChan:
if packet == nil {
return
}
case <-ctx.Done():
return
}
p.tunOutbound <- packet p.tunOutbound <- packet
} }
} }
func (p *Peer) routeTun(ctx context.Context) { func (p *Peer) routeTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
for packet := range p.tunInbound { for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-p.tunInbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok {
plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr()) plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr())
copy(packet.data[1:packet.length+1], packet.data[:packet.length]) copy(packet.data[1:packet.length+1], packet.data[:packet.length])

View File

@@ -77,7 +77,7 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
defer util.HandleCrash() defer util.HandleCrash()
var gvisorInbound = make(chan *Packet, MaxSize) var gvisorInbound = make(chan *Packet, MaxSize)
go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx) go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx)
for { for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime)) err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime))
if err != nil { if err != nil {
@@ -99,22 +99,25 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
continue continue
} }
if buf[0] == 1 { if buf[0] == 1 {
util.SafeWrite(gvisorInbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) { gvisorInbound <- NewPacket(buf[:], n, nil, nil)
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
})
} else { } else {
util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) { tunOutbound <- NewPacket(buf[:], n, nil, nil)
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
})
} }
} }
} }
func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) { func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) {
defer util.HandleCrash() defer util.HandleCrash()
for packet := range inbound { for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-inbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime)) err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime))
if err != nil { if err != nil {
plog.G(ctx).Errorf("Failed to set write deadline: %v", err) plog.G(ctx).Errorf("Failed to set write deadline: %v", err)
@@ -135,7 +138,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
var gvisorInbound = make(chan *Packet, MaxSize) var gvisorInbound = make(chan *Packet, MaxSize)
go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx) go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx)
for { for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
n, err := d.tun.Read(buf[1:]) n, err := d.tun.Read(buf[1:])
if err != nil { if err != nil {
@@ -158,21 +161,27 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
} }
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
packet := NewPacket(buf[:], n+1, src, dst) packet := NewPacket(buf[:], n+1, src, dst)
f := func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
}
if packet.src.Equal(packet.dst) { if packet.src.Equal(packet.dst) {
util.SafeWrite(gvisorInbound, packet, f) gvisorInbound <- packet
} else { } else {
util.SafeWrite(d.tunInbound, packet, f) d.tunInbound <- packet
} }
} }
} }
func (d *ClientDevice) writeToTun(ctx context.Context) { func (d *ClientDevice) writeToTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
for packet := range d.tunOutbound { for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-d.tunOutbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
_, err := d.tun.Write(packet.data[1:packet.length]) _, err := d.tun.Write(packet.data[1:packet.length])
config.LPool.Put(packet.data[:]) config.LPool.Put(packet.data[:])
if err != nil { if err != nil {
@@ -185,8 +194,6 @@ func (d *ClientDevice) writeToTun(ctx context.Context) {
func (d *ClientDevice) Close() { func (d *ClientDevice) Close() {
d.tun.Close() d.tun.Close()
util.SafeClose(d.tunInbound)
util.SafeClose(d.tunOutbound)
} }
func (d *ClientDevice) heartbeats(ctx context.Context) { func (d *ClientDevice) heartbeats(ctx context.Context) {
@@ -214,10 +221,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
data := config.LPool.Get().([]byte)[:] data := config.LPool.Get().([]byte)[:]
length := copy(data[1:], bytes) length := copy(data[1:], bytes)
data[0] = 1 data[0] = 1
util.SafeWrite(d.tunInbound, &Packet{ d.tunInbound <- &Packet{
data: data[:], data: data[:],
length: length + 1, length: length + 1,
}) }
} }
} }
if srcIPv6 != nil { if srcIPv6 != nil {
@@ -228,10 +235,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
data := config.LPool.Get().([]byte)[:] data := config.LPool.Get().([]byte)[:]
length := copy(data[1:], bytes6) length := copy(data[1:], bytes6)
data[0] = 1 data[0] = 1
util.SafeWrite(d.tunInbound, &Packet{ d.tunInbound <- &Packet{
data: data[:], data: data[:],
length: length + 1, length: length + 1,
}) }
} }
} }