refactor: no drop packet (#703)

This commit is contained in:
naison
2025-08-14 19:13:17 +08:00
committed by GitHub
parent 4ddba64737
commit 4df63d1642
5 changed files with 84 additions and 45 deletions

View File

@@ -15,12 +15,12 @@ type bufferedTCP struct {
closed bool
}
func NewBufferedTCP(conn net.Conn) net.Conn {
func NewBufferedTCP(ctx context.Context, conn net.Conn) net.Conn {
c := &bufferedTCP{
Conn: conn,
Chan: make(chan *DatagramPacket, MaxSize),
}
go c.Run()
go c.Run(ctx)
return c
}
@@ -38,8 +38,17 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) {
return n, nil
}
func (c *bufferedTCP) Run() {
for buf := range c.Chan {
func (c *bufferedTCP) Run(ctx context.Context) {
for ctx.Err() == nil {
var buf *DatagramPacket
select {
case buf = <-c.Chan:
if buf == nil {
return
}
case <-ctx.Done():
return
}
_, err := c.Conn.Write(buf.Data[:buf.DataLength])
config.LPool.Put(buf.Data[:])
if err != nil {
@@ -50,3 +59,8 @@ func (c *bufferedTCP) Run() {
}
}
}
func (c *bufferedTCP) Close() error {
c.closed = true
return c.Conn.Close()
}

View File

@@ -45,7 +45,7 @@ func (h *gvisorTCPHandler) handle(ctx context.Context, tcpConn net.Conn) {
errChan := make(chan error, 2)
go func() {
defer util.HandleCrash()
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(tcpConn), endpoint)
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(ctx, tcpConn), endpoint)
util.SafeClose(errChan)
}()
go func() {

View File

@@ -114,15 +114,12 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c
pkt.DecRef()
plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
} else {
util.SafeWrite(TCPPacketChan, &Packet{
TCPPacketChan <- &Packet{
data: buf[:],
length: read,
src: src,
dst: dst,
}, func(v *Packet) {
config.LPool.Put(buf[:])
plog.G(ctx).Debugf("[TCP-TUN] Drop packet. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
})
}
}
}
}

View File

@@ -74,7 +74,7 @@ type Device struct {
func (d *Device) readFromTun(ctx context.Context) {
defer util.HandleCrash()
for {
for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:]
n, err := d.tun.Read(buf[:])
if err != nil {
@@ -92,16 +92,22 @@ func (d *Device) readFromTun(ctx context.Context) {
}
plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
})
d.tunInbound <- NewPacket(buf[:], n, src, dst)
}
}
func (d *Device) writeToTun(ctx context.Context) {
defer util.HandleCrash()
for packet := range d.tunOutbound {
for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-d.tunOutbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
_, err := d.tun.Write(packet.data[1:packet.length])
config.LPool.Put(packet.data[:])
if err != nil {
@@ -114,9 +120,6 @@ func (d *Device) writeToTun(ctx context.Context) {
func (d *Device) Close() {
d.tun.Close()
util.SafeClose(d.tunInbound)
util.SafeClose(d.tunOutbound)
util.SafeClose(TCPPacketChan)
}
func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) {
@@ -183,14 +186,32 @@ func (p *Peer) sendErr(err error) {
func (p *Peer) routeTCPToTun(ctx context.Context) {
defer util.HandleCrash()
for packet := range TCPPacketChan {
for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-TCPPacketChan:
if packet == nil {
return
}
case <-ctx.Done():
return
}
p.tunOutbound <- packet
}
}
func (p *Peer) routeTun(ctx context.Context) {
defer util.HandleCrash()
for packet := range p.tunInbound {
for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-p.tunInbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok {
plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr())
copy(packet.data[1:packet.length+1], packet.data[:packet.length])

View File

@@ -77,7 +77,7 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
defer util.HandleCrash()
var gvisorInbound = make(chan *Packet, MaxSize)
go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx)
for {
for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:]
err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime))
if err != nil {
@@ -99,22 +99,25 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
continue
}
if buf[0] == 1 {
util.SafeWrite(gvisorInbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
})
gvisorInbound <- NewPacket(buf[:], n, nil, nil)
} else {
util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
})
tunOutbound <- NewPacket(buf[:], n, nil, nil)
}
}
}
func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) {
defer util.HandleCrash()
for packet := range inbound {
for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-inbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime))
if err != nil {
plog.G(ctx).Errorf("Failed to set write deadline: %v", err)
@@ -135,7 +138,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
defer util.HandleCrash()
var gvisorInbound = make(chan *Packet, MaxSize)
go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx)
for {
for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:]
n, err := d.tun.Read(buf[1:])
if err != nil {
@@ -158,21 +161,27 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
}
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
packet := NewPacket(buf[:], n+1, src, dst)
f := func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
}
if packet.src.Equal(packet.dst) {
util.SafeWrite(gvisorInbound, packet, f)
gvisorInbound <- packet
} else {
util.SafeWrite(d.tunInbound, packet, f)
d.tunInbound <- packet
}
}
}
func (d *ClientDevice) writeToTun(ctx context.Context) {
defer util.HandleCrash()
for packet := range d.tunOutbound {
for ctx.Err() == nil {
var packet *Packet
select {
case packet = <-d.tunOutbound:
if packet == nil {
return
}
case <-ctx.Done():
return
}
_, err := d.tun.Write(packet.data[1:packet.length])
config.LPool.Put(packet.data[:])
if err != nil {
@@ -185,8 +194,6 @@ func (d *ClientDevice) writeToTun(ctx context.Context) {
func (d *ClientDevice) Close() {
d.tun.Close()
util.SafeClose(d.tunInbound)
util.SafeClose(d.tunOutbound)
}
func (d *ClientDevice) heartbeats(ctx context.Context) {
@@ -214,10 +221,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
data := config.LPool.Get().([]byte)[:]
length := copy(data[1:], bytes)
data[0] = 1
util.SafeWrite(d.tunInbound, &Packet{
d.tunInbound <- &Packet{
data: data[:],
length: length + 1,
})
}
}
}
if srcIPv6 != nil {
@@ -228,10 +235,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
data := config.LPool.Get().([]byte)[:]
length := copy(data[1:], bytes6)
data[0] = 1
util.SafeWrite(d.tunInbound, &Packet{
d.tunInbound <- &Packet{
data: data[:],
length: length + 1,
})
}
}
}