mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-09-26 19:31:17 +08:00
refactor: no drop packet (#703)
This commit is contained in:
@@ -15,12 +15,12 @@ type bufferedTCP struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewBufferedTCP(conn net.Conn) net.Conn {
|
||||
func NewBufferedTCP(ctx context.Context, conn net.Conn) net.Conn {
|
||||
c := &bufferedTCP{
|
||||
Conn: conn,
|
||||
Chan: make(chan *DatagramPacket, MaxSize),
|
||||
}
|
||||
go c.Run()
|
||||
go c.Run(ctx)
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -38,8 +38,17 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *bufferedTCP) Run() {
|
||||
for buf := range c.Chan {
|
||||
func (c *bufferedTCP) Run(ctx context.Context) {
|
||||
for ctx.Err() == nil {
|
||||
var buf *DatagramPacket
|
||||
select {
|
||||
case buf = <-c.Chan:
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
_, err := c.Conn.Write(buf.Data[:buf.DataLength])
|
||||
config.LPool.Put(buf.Data[:])
|
||||
if err != nil {
|
||||
@@ -50,3 +59,8 @@ func (c *bufferedTCP) Run() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *bufferedTCP) Close() error {
|
||||
c.closed = true
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
@@ -45,7 +45,7 @@ func (h *gvisorTCPHandler) handle(ctx context.Context, tcpConn net.Conn) {
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
defer util.HandleCrash()
|
||||
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(tcpConn), endpoint)
|
||||
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(ctx, tcpConn), endpoint)
|
||||
util.SafeClose(errChan)
|
||||
}()
|
||||
go func() {
|
||||
|
@@ -114,15 +114,12 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c
|
||||
pkt.DecRef()
|
||||
plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
|
||||
} else {
|
||||
util.SafeWrite(TCPPacketChan, &Packet{
|
||||
TCPPacketChan <- &Packet{
|
||||
data: buf[:],
|
||||
length: read,
|
||||
src: src,
|
||||
dst: dst,
|
||||
}, func(v *Packet) {
|
||||
config.LPool.Put(buf[:])
|
||||
plog.G(ctx).Debugf("[TCP-TUN] Drop packet. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -74,7 +74,7 @@ type Device struct {
|
||||
|
||||
func (d *Device) readFromTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
for {
|
||||
for ctx.Err() == nil {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
n, err := d.tun.Read(buf[:])
|
||||
if err != nil {
|
||||
@@ -92,16 +92,22 @@ func (d *Device) readFromTun(ctx context.Context) {
|
||||
}
|
||||
|
||||
plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
|
||||
util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) {
|
||||
config.LPool.Put(v.data[:])
|
||||
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
|
||||
})
|
||||
d.tunInbound <- NewPacket(buf[:], n, src, dst)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Device) writeToTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
for packet := range d.tunOutbound {
|
||||
for ctx.Err() == nil {
|
||||
var packet *Packet
|
||||
select {
|
||||
case packet = <-d.tunOutbound:
|
||||
if packet == nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
_, err := d.tun.Write(packet.data[1:packet.length])
|
||||
config.LPool.Put(packet.data[:])
|
||||
if err != nil {
|
||||
@@ -114,9 +120,6 @@ func (d *Device) writeToTun(ctx context.Context) {
|
||||
|
||||
func (d *Device) Close() {
|
||||
d.tun.Close()
|
||||
util.SafeClose(d.tunInbound)
|
||||
util.SafeClose(d.tunOutbound)
|
||||
util.SafeClose(TCPPacketChan)
|
||||
}
|
||||
|
||||
func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) {
|
||||
@@ -183,14 +186,32 @@ func (p *Peer) sendErr(err error) {
|
||||
|
||||
func (p *Peer) routeTCPToTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
for packet := range TCPPacketChan {
|
||||
for ctx.Err() == nil {
|
||||
var packet *Packet
|
||||
select {
|
||||
case packet = <-TCPPacketChan:
|
||||
if packet == nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
p.tunOutbound <- packet
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) routeTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
for packet := range p.tunInbound {
|
||||
for ctx.Err() == nil {
|
||||
var packet *Packet
|
||||
select {
|
||||
case packet = <-p.tunInbound:
|
||||
if packet == nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok {
|
||||
plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr())
|
||||
copy(packet.data[1:packet.length+1], packet.data[:packet.length])
|
||||
|
@@ -77,7 +77,7 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
|
||||
defer util.HandleCrash()
|
||||
var gvisorInbound = make(chan *Packet, MaxSize)
|
||||
go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx)
|
||||
for {
|
||||
for ctx.Err() == nil {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime))
|
||||
if err != nil {
|
||||
@@ -99,22 +99,25 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
|
||||
continue
|
||||
}
|
||||
if buf[0] == 1 {
|
||||
util.SafeWrite(gvisorInbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
|
||||
config.LPool.Put(v.data[:])
|
||||
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
|
||||
})
|
||||
gvisorInbound <- NewPacket(buf[:], n, nil, nil)
|
||||
} else {
|
||||
util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
|
||||
config.LPool.Put(v.data[:])
|
||||
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
|
||||
})
|
||||
tunOutbound <- NewPacket(buf[:], n, nil, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) {
|
||||
defer util.HandleCrash()
|
||||
for packet := range inbound {
|
||||
for ctx.Err() == nil {
|
||||
var packet *Packet
|
||||
select {
|
||||
case packet = <-inbound:
|
||||
if packet == nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime))
|
||||
if err != nil {
|
||||
plog.G(ctx).Errorf("Failed to set write deadline: %v", err)
|
||||
@@ -135,7 +138,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
var gvisorInbound = make(chan *Packet, MaxSize)
|
||||
go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx)
|
||||
for {
|
||||
for ctx.Err() == nil {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
n, err := d.tun.Read(buf[1:])
|
||||
if err != nil {
|
||||
@@ -158,21 +161,27 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
|
||||
}
|
||||
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
|
||||
packet := NewPacket(buf[:], n+1, src, dst)
|
||||
f := func(v *Packet) {
|
||||
config.LPool.Put(v.data[:])
|
||||
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
|
||||
}
|
||||
if packet.src.Equal(packet.dst) {
|
||||
util.SafeWrite(gvisorInbound, packet, f)
|
||||
gvisorInbound <- packet
|
||||
} else {
|
||||
util.SafeWrite(d.tunInbound, packet, f)
|
||||
d.tunInbound <- packet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *ClientDevice) writeToTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
for packet := range d.tunOutbound {
|
||||
for ctx.Err() == nil {
|
||||
var packet *Packet
|
||||
select {
|
||||
case packet = <-d.tunOutbound:
|
||||
if packet == nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
_, err := d.tun.Write(packet.data[1:packet.length])
|
||||
config.LPool.Put(packet.data[:])
|
||||
if err != nil {
|
||||
@@ -185,8 +194,6 @@ func (d *ClientDevice) writeToTun(ctx context.Context) {
|
||||
|
||||
func (d *ClientDevice) Close() {
|
||||
d.tun.Close()
|
||||
util.SafeClose(d.tunInbound)
|
||||
util.SafeClose(d.tunOutbound)
|
||||
}
|
||||
|
||||
func (d *ClientDevice) heartbeats(ctx context.Context) {
|
||||
@@ -214,10 +221,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
|
||||
data := config.LPool.Get().([]byte)[:]
|
||||
length := copy(data[1:], bytes)
|
||||
data[0] = 1
|
||||
util.SafeWrite(d.tunInbound, &Packet{
|
||||
d.tunInbound <- &Packet{
|
||||
data: data[:],
|
||||
length: length + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if srcIPv6 != nil {
|
||||
@@ -228,10 +235,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
|
||||
data := config.LPool.Get().([]byte)[:]
|
||||
length := copy(data[1:], bytes6)
|
||||
data[0] = 1
|
||||
util.SafeWrite(d.tunInbound, &Packet{
|
||||
d.tunInbound <- &Packet{
|
||||
data: data[:],
|
||||
length: length + 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user