mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-09-27 03:36:09 +08:00
refactor: no drop packet (#703)
This commit is contained in:
@@ -15,12 +15,12 @@ type bufferedTCP struct {
|
|||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBufferedTCP(conn net.Conn) net.Conn {
|
func NewBufferedTCP(ctx context.Context, conn net.Conn) net.Conn {
|
||||||
c := &bufferedTCP{
|
c := &bufferedTCP{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
Chan: make(chan *DatagramPacket, MaxSize),
|
Chan: make(chan *DatagramPacket, MaxSize),
|
||||||
}
|
}
|
||||||
go c.Run()
|
go c.Run(ctx)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,8 +38,17 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *bufferedTCP) Run() {
|
func (c *bufferedTCP) Run(ctx context.Context) {
|
||||||
for buf := range c.Chan {
|
for ctx.Err() == nil {
|
||||||
|
var buf *DatagramPacket
|
||||||
|
select {
|
||||||
|
case buf = <-c.Chan:
|
||||||
|
if buf == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
_, err := c.Conn.Write(buf.Data[:buf.DataLength])
|
_, err := c.Conn.Write(buf.Data[:buf.DataLength])
|
||||||
config.LPool.Put(buf.Data[:])
|
config.LPool.Put(buf.Data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -50,3 +59,8 @@ func (c *bufferedTCP) Run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *bufferedTCP) Close() error {
|
||||||
|
c.closed = true
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
@@ -45,7 +45,7 @@ func (h *gvisorTCPHandler) handle(ctx context.Context, tcpConn net.Conn) {
|
|||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
go func() {
|
go func() {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(tcpConn), endpoint)
|
h.readFromTCPConnWriteToEndpoint(ctx, NewBufferedTCP(ctx, tcpConn), endpoint)
|
||||||
util.SafeClose(errChan)
|
util.SafeClose(errChan)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
|
@@ -114,15 +114,12 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c
|
|||||||
pkt.DecRef()
|
pkt.DecRef()
|
||||||
plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
|
plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
|
||||||
} else {
|
} else {
|
||||||
util.SafeWrite(TCPPacketChan, &Packet{
|
TCPPacketChan <- &Packet{
|
||||||
data: buf[:],
|
data: buf[:],
|
||||||
length: read,
|
length: read,
|
||||||
src: src,
|
src: src,
|
||||||
dst: dst,
|
dst: dst,
|
||||||
}, func(v *Packet) {
|
}
|
||||||
config.LPool.Put(buf[:])
|
|
||||||
plog.G(ctx).Debugf("[TCP-TUN] Drop packet. SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(ipProtocol).String(), read)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -74,7 +74,7 @@ type Device struct {
|
|||||||
|
|
||||||
func (d *Device) readFromTun(ctx context.Context) {
|
func (d *Device) readFromTun(ctx context.Context) {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
for {
|
for ctx.Err() == nil {
|
||||||
buf := config.LPool.Get().([]byte)[:]
|
buf := config.LPool.Get().([]byte)[:]
|
||||||
n, err := d.tun.Read(buf[:])
|
n, err := d.tun.Read(buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -92,16 +92,22 @@ func (d *Device) readFromTun(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
|
plog.G(ctx).Debugf("[TUN] SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
|
||||||
util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) {
|
d.tunInbound <- NewPacket(buf[:], n, src, dst)
|
||||||
config.LPool.Put(v.data[:])
|
|
||||||
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Device) writeToTun(ctx context.Context) {
|
func (d *Device) writeToTun(ctx context.Context) {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
for packet := range d.tunOutbound {
|
for ctx.Err() == nil {
|
||||||
|
var packet *Packet
|
||||||
|
select {
|
||||||
|
case packet = <-d.tunOutbound:
|
||||||
|
if packet == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
_, err := d.tun.Write(packet.data[1:packet.length])
|
_, err := d.tun.Write(packet.data[1:packet.length])
|
||||||
config.LPool.Put(packet.data[:])
|
config.LPool.Put(packet.data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -114,9 +120,6 @@ func (d *Device) writeToTun(ctx context.Context) {
|
|||||||
|
|
||||||
func (d *Device) Close() {
|
func (d *Device) Close() {
|
||||||
d.tun.Close()
|
d.tun.Close()
|
||||||
util.SafeClose(d.tunInbound)
|
|
||||||
util.SafeClose(d.tunOutbound)
|
|
||||||
util.SafeClose(TCPPacketChan)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) {
|
func (d *Device) handlePacket(ctx context.Context, routeMapTCP *sync.Map) {
|
||||||
@@ -183,14 +186,32 @@ func (p *Peer) sendErr(err error) {
|
|||||||
|
|
||||||
func (p *Peer) routeTCPToTun(ctx context.Context) {
|
func (p *Peer) routeTCPToTun(ctx context.Context) {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
for packet := range TCPPacketChan {
|
for ctx.Err() == nil {
|
||||||
|
var packet *Packet
|
||||||
|
select {
|
||||||
|
case packet = <-TCPPacketChan:
|
||||||
|
if packet == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
p.tunOutbound <- packet
|
p.tunOutbound <- packet
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) routeTun(ctx context.Context) {
|
func (p *Peer) routeTun(ctx context.Context) {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
for packet := range p.tunInbound {
|
for ctx.Err() == nil {
|
||||||
|
var packet *Packet
|
||||||
|
select {
|
||||||
|
case packet = <-p.tunInbound:
|
||||||
|
if packet == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok {
|
if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok {
|
||||||
plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr())
|
plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr())
|
||||||
copy(packet.data[1:packet.length+1], packet.data[:packet.length])
|
copy(packet.data[1:packet.length+1], packet.data[:packet.length])
|
||||||
|
@@ -77,7 +77,7 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
|
|||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
var gvisorInbound = make(chan *Packet, MaxSize)
|
var gvisorInbound = make(chan *Packet, MaxSize)
|
||||||
go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx)
|
go handleGvisorPacket(gvisorInbound, tunInbound).Run(ctx)
|
||||||
for {
|
for ctx.Err() == nil {
|
||||||
buf := config.LPool.Get().([]byte)[:]
|
buf := config.LPool.Get().([]byte)[:]
|
||||||
err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime))
|
err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,22 +99,25 @@ func readFromConn(ctx context.Context, conn net.Conn, tunInbound chan *Packet, t
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if buf[0] == 1 {
|
if buf[0] == 1 {
|
||||||
util.SafeWrite(gvisorInbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
|
gvisorInbound <- NewPacket(buf[:], n, nil, nil)
|
||||||
config.LPool.Put(v.data[:])
|
|
||||||
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
|
tunOutbound <- NewPacket(buf[:], n, nil, nil)
|
||||||
config.LPool.Put(v.data[:])
|
|
||||||
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) {
|
func writeToConn(ctx context.Context, conn net.Conn, inbound <-chan *Packet, errChan chan error) {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
for packet := range inbound {
|
for ctx.Err() == nil {
|
||||||
|
var packet *Packet
|
||||||
|
select {
|
||||||
|
case packet = <-inbound:
|
||||||
|
if packet == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime))
|
err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
plog.G(ctx).Errorf("Failed to set write deadline: %v", err)
|
plog.G(ctx).Errorf("Failed to set write deadline: %v", err)
|
||||||
@@ -135,7 +138,7 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
|
|||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
var gvisorInbound = make(chan *Packet, MaxSize)
|
var gvisorInbound = make(chan *Packet, MaxSize)
|
||||||
go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx)
|
go handleGvisorPacket(gvisorInbound, d.tunOutbound).Run(ctx)
|
||||||
for {
|
for ctx.Err() == nil {
|
||||||
buf := config.LPool.Get().([]byte)[:]
|
buf := config.LPool.Get().([]byte)[:]
|
||||||
n, err := d.tun.Read(buf[1:])
|
n, err := d.tun.Read(buf[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -158,21 +161,27 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
|
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
|
||||||
packet := NewPacket(buf[:], n+1, src, dst)
|
packet := NewPacket(buf[:], n+1, src, dst)
|
||||||
f := func(v *Packet) {
|
|
||||||
config.LPool.Put(v.data[:])
|
|
||||||
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
|
|
||||||
}
|
|
||||||
if packet.src.Equal(packet.dst) {
|
if packet.src.Equal(packet.dst) {
|
||||||
util.SafeWrite(gvisorInbound, packet, f)
|
gvisorInbound <- packet
|
||||||
} else {
|
} else {
|
||||||
util.SafeWrite(d.tunInbound, packet, f)
|
d.tunInbound <- packet
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *ClientDevice) writeToTun(ctx context.Context) {
|
func (d *ClientDevice) writeToTun(ctx context.Context) {
|
||||||
defer util.HandleCrash()
|
defer util.HandleCrash()
|
||||||
for packet := range d.tunOutbound {
|
for ctx.Err() == nil {
|
||||||
|
var packet *Packet
|
||||||
|
select {
|
||||||
|
case packet = <-d.tunOutbound:
|
||||||
|
if packet == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
_, err := d.tun.Write(packet.data[1:packet.length])
|
_, err := d.tun.Write(packet.data[1:packet.length])
|
||||||
config.LPool.Put(packet.data[:])
|
config.LPool.Put(packet.data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -185,8 +194,6 @@ func (d *ClientDevice) writeToTun(ctx context.Context) {
|
|||||||
|
|
||||||
func (d *ClientDevice) Close() {
|
func (d *ClientDevice) Close() {
|
||||||
d.tun.Close()
|
d.tun.Close()
|
||||||
util.SafeClose(d.tunInbound)
|
|
||||||
util.SafeClose(d.tunOutbound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *ClientDevice) heartbeats(ctx context.Context) {
|
func (d *ClientDevice) heartbeats(ctx context.Context) {
|
||||||
@@ -214,10 +221,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
|
|||||||
data := config.LPool.Get().([]byte)[:]
|
data := config.LPool.Get().([]byte)[:]
|
||||||
length := copy(data[1:], bytes)
|
length := copy(data[1:], bytes)
|
||||||
data[0] = 1
|
data[0] = 1
|
||||||
util.SafeWrite(d.tunInbound, &Packet{
|
d.tunInbound <- &Packet{
|
||||||
data: data[:],
|
data: data[:],
|
||||||
length: length + 1,
|
length: length + 1,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if srcIPv6 != nil {
|
if srcIPv6 != nil {
|
||||||
@@ -228,10 +235,10 @@ func (d *ClientDevice) heartbeats(ctx context.Context) {
|
|||||||
data := config.LPool.Get().([]byte)[:]
|
data := config.LPool.Get().([]byte)[:]
|
||||||
length := copy(data[1:], bytes6)
|
length := copy(data[1:], bytes6)
|
||||||
data[0] = 1
|
data[0] = 1
|
||||||
util.SafeWrite(d.tunInbound, &Packet{
|
d.tunInbound <- &Packet{
|
||||||
data: data[:],
|
data: data[:],
|
||||||
length: length + 1,
|
length: length + 1,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user