refactor: add additional [2]byte for packet length (#554)

This commit is contained in:
naison
2025-04-21 21:51:01 +08:00
committed by GitHub
parent dd80717d8d
commit a3556a263d
10 changed files with 56 additions and 75 deletions

View File

@@ -34,10 +34,7 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) {
buf := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
n = copy(buf, b) n = copy(buf, b)
c.Chan <- &DatagramPacket{ c.Chan <- newDatagramPacket(buf, n)
DataLength: uint16(n),
Data: buf,
}
return n, nil return n, nil
} }

View File

@@ -73,15 +73,15 @@ func TCPForwarder(ctx context.Context, s *stack.Stack) func(stack.TransportEndpo
func WriteProxyInfo(conn net.Conn, id stack.TransportEndpointID) error { func WriteProxyInfo(conn net.Conn, id stack.TransportEndpointID) error {
var b bytes.Buffer var b bytes.Buffer
i := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
defer config.LPool.Put(i[:]) defer config.LPool.Put(buf[:])
// local port // local port
binary.BigEndian.PutUint16(i, id.LocalPort) binary.BigEndian.PutUint16(buf, id.LocalPort)
b.Write(i) b.Write(buf)
// remote port // remote port
binary.BigEndian.PutUint16(i, id.RemotePort) binary.BigEndian.PutUint16(buf, id.RemotePort)
b.Write(i) b.Write(buf)
// local address // local address
b.WriteByte(byte(id.LocalAddress.Len())) b.WriteByte(byte(id.LocalAddress.Len()))

View File

@@ -25,8 +25,11 @@ func (h *gvisorTCPHandler) readFromEndpointWriteToTCPConn(ctx context.Context, c
pktBuffer := endpoint.ReadContext(ctx) pktBuffer := endpoint.ReadContext(ctx)
if pktBuffer != nil { if pktBuffer != nil {
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pktBuffer.NetworkProtocolNumber, pktBuffer) sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pktBuffer.NetworkProtocolNumber, pktBuffer)
buf := pktBuffer.ToView().AsSlice() data := pktBuffer.ToView().AsSlice()
_, err := tcpConn.Write(buf) buf := config.LPool.Get().([]byte)[:]
n := copy(buf, data)
_, err := tcpConn.Write(buf[:n+2])
config.LPool.Put(buf)
if err != nil { if err != nil {
plog.G(ctx).Errorf("[TUN-GVISOR] Failed to write data to tun device: %v", err) plog.G(ctx).Errorf("[TUN-GVISOR] Failed to write data to tun device: %v", err)
} }
@@ -111,7 +114,7 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c
func (h *gvisorTCPHandler) handlePacket(ctx context.Context, buf []byte, length int, src, dst net.IP, protocol string) error { func (h *gvisorTCPHandler) handlePacket(ctx context.Context, buf []byte, length int, src, dst net.IP, protocol string) error {
if conn, ok := h.routeMapTCP.Load(dst.String()); ok { if conn, ok := h.routeMapTCP.Load(dst.String()); ok {
plog.G(ctx).Debugf("[TCP-GVISOR] Find TCP route SRC: %s to DST: %s -> %s", src, dst, conn.(net.Conn).RemoteAddr()) plog.G(ctx).Debugf("[TCP-GVISOR] Find TCP route SRC: %s to DST: %s -> %s", src, dst, conn.(net.Conn).RemoteAddr())
dgram := newDatagramPacket(buf[:length]) dgram := newDatagramPacket(buf, length)
err := dgram.Write(conn.(net.Conn)) err := dgram.Write(conn.(net.Conn))
config.LPool.Put(buf[:]) config.LPool.Put(buf[:])
if err != nil { if err != nil {

View File

@@ -68,7 +68,7 @@ func (c *gvisorUDPConnOverTCP) Read(b []byte) (int, error) {
} }
func (c *gvisorUDPConnOverTCP) Write(b []byte) (int, error) { func (c *gvisorUDPConnOverTCP) Write(b []byte) (int, error) {
packet := newDatagramPacket(b) packet := newDatagramPacket(b, len(b)-2)
if err := packet.Write(c.Conn); err != nil { if err := packet.Write(c.Conn); err != nil {
return 0, err return 0, err
} }
@@ -114,7 +114,7 @@ func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) {
errChan <- err errChan <- err
return return
} }
datagram, err := readDatagramPacket(tcpConn, buf[:]) datagram, err := readDatagramPacket(tcpConn, buf)
if err != nil { if err != nil {
plog.G(ctx).Errorf("[TUN-UDP] %s -> %s: %v", tcpConn.RemoteAddr(), udpConn.LocalAddr(), err) plog.G(ctx).Errorf("[TUN-UDP] %s -> %s: %v", tcpConn.RemoteAddr(), udpConn.LocalAddr(), err)
errChan <- err errChan <- err
@@ -172,7 +172,7 @@ func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) {
errChan <- err errChan <- err
return return
} }
packet := newDatagramPacket(buf[:n]) packet := newDatagramPacket(buf, n)
if err = packet.Write(tcpConn); err != nil { if err = packet.Write(tcpConn); err != nil {
plog.G(ctx).Errorf("[TUN-UDP] Error: %s <- %s : %s", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err) plog.G(ctx).Errorf("[TUN-UDP] Error: %s <- %s : %s", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err)
errChan <- err errChan <- err

View File

@@ -29,7 +29,7 @@ func TCPTransporter(tlsInfo map[string][]byte) Transporter {
} }
func (tr *tcpTransporter) Dial(ctx context.Context, addr string) (net.Conn, error) { func (tr *tcpTransporter) Dial(ctx context.Context, addr string) (net.Conn, error) {
dialer := &net.Dialer{Timeout: config.DialTimeout} dialer := &net.Dialer{Timeout: config.DialTimeout, KeepAlive: config.KeepAliveTime}
conn, err := dialer.DialContext(ctx, "tcp", addr) conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -61,7 +61,7 @@ func (h *UDPOverTCPHandler) Handle(ctx context.Context, tcpConn net.Conn) {
for ctx.Err() == nil { for ctx.Err() == nil {
buf := config.LPool.Get().([]byte)[:] buf := config.LPool.Get().([]byte)[:]
datagram, err := readDatagramPacket(tcpConn, buf[:]) datagram, err := readDatagramPacket(tcpConn, buf)
if err != nil { if err != nil {
plog.G(ctx).Errorf("[TCP] Failed to read from %s -> %s: %v", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err) plog.G(ctx).Errorf("[TCP] Failed to read from %s -> %s: %v", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err)
config.LPool.Put(buf[:]) config.LPool.Put(buf[:])
@@ -160,7 +160,7 @@ func (c *UDPConnOverTCP) ReadFrom(b []byte) (int, net.Addr, error) {
} }
func (c *UDPConnOverTCP) WriteTo(b []byte, _ net.Addr) (int, error) { func (c *UDPConnOverTCP) WriteTo(b []byte, _ net.Addr) (int, error) {
packet := newDatagramPacket(b) packet := newDatagramPacket(b, len(b)-2)
if err := packet.Write(c.Conn); err != nil { if err := packet.Write(c.Conn); err != nil {
return 0, err return 0, err
} }

View File

@@ -59,7 +59,7 @@ func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) {
defer device.Close() defer device.Close()
go device.readFromTUN(ctx) go device.readFromTUN(ctx)
go device.writeToTUN(ctx) go device.writeToTUN(ctx)
go device.transport(ctx, h.node.Addr, h.routeMapUDP, h.routeMapTCP) go device.handlePacket(ctx, h.node.Addr, h.routeMapUDP, h.routeMapTCP)
select { select {
case err := <-device.errChan: case err := <-device.errChan:
@@ -131,7 +131,7 @@ func (d *Device) Close() {
util.SafeClose(TCPPacketChan) util.SafeClose(TCPPacketChan)
} }
func (d *Device) transport(ctx context.Context, addr string, routeMapUDP *sync.Map, routeMapTCP *sync.Map) { func (d *Device) handlePacket(ctx context.Context, addr string, routeMapUDP *sync.Map, routeMapTCP *sync.Map) {
packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", addr) packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", addr)
if err != nil { if err != nil {
util.SafeWrite(d.errChan, err) util.SafeWrite(d.errChan, err)
@@ -268,7 +268,7 @@ func (p *Peer) routeTUN(ctx context.Context) {
} }
} else if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { } else if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok {
plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr()) plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr())
dgram := newDatagramPacket(packet.data[:packet.length]) dgram := newDatagramPacket(packet.data, packet.length)
err := dgram.Write(conn.(net.Conn)) err := dgram.Write(conn.(net.Conn))
config.LPool.Put(packet.data[:]) config.LPool.Put(packet.data[:])
if err != nil { if err != nil {

View File

@@ -23,7 +23,7 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr
} }
defer device.Close() defer device.Close()
go device.forwardPacketToRemote(ctx, remoteAddr, h.forward) go device.handlePacket(ctx, remoteAddr, h.forward)
go device.readFromTun(ctx) go device.readFromTun(ctx)
go device.writeToTun(ctx) go device.writeToTun(ctx)
go heartbeats(ctx, device.tun) go heartbeats(ctx, device.tun)
@@ -43,51 +43,35 @@ type ClientDevice struct {
forward *Forwarder forward *Forwarder
} }
func (d *ClientDevice) forwardPacketToRemote(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) { func (d *ClientDevice) handlePacket(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) {
for ctx.Err() == nil { for ctx.Err() == nil {
func() { packetConn, err := getRemotePacketConn(ctx, forward)
packetConn, err := getRemotePacketConn(ctx, forward) if err != nil {
if err != nil { plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err)
plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err) time.Sleep(time.Second * 1)
time.Sleep(time.Second * 1) continue
return }
} err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr)
err = transportTunPacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr) if err != nil {
if err != nil { plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err)
plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err) }
}
}()
} }
} }
func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (packetConn net.PacketConn, err error) { func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (net.PacketConn, error) {
defer func() { conn, err := forwarder.DialContext(ctx)
if err != nil && packetConn != nil { if err != nil {
_ = packetConn.Close() return nil, errors.Wrap(err, "failed to dial forwarder")
} }
}()
if !forwarder.IsEmpty() { if packetConn, ok := conn.(net.PacketConn); !ok {
var conn net.Conn return nil, errors.Errorf("failed to cast packet conn to PacketConn")
conn, err = forwarder.DialContext(ctx)
if err != nil {
return
}
var ok bool
if packetConn, ok = conn.(net.PacketConn); !ok {
err = errors.New("not a packet connection")
return
}
} else { } else {
var lc net.ListenConfig return packetConn, nil
packetConn, err = lc.ListenPacket(ctx, "udp", "")
if err != nil {
return
}
} }
return
} }
func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error { func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error {
errChan := make(chan error, 2) errChan := make(chan error, 2)
defer packetConn.Close() defer packetConn.Close()
@@ -105,7 +89,7 @@ func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tu
}) })
continue continue
} }
_, err := packetConn.WriteTo(packet.data[:packet.length], remoteAddr) _, err := packetConn.WriteTo(packet.data[:packet.length+2], remoteAddr)
config.LPool.Put(packet.data[:]) config.LPool.Put(packet.data[:])
if err != nil { if err != nil {
util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr))) util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr)))
@@ -126,6 +110,7 @@ func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tu
} }
if n == 0 { if n == 0 {
plog.G(ctx).Warnf("Packet length 0") plog.G(ctx).Warnf("Packet length 0")
config.LPool.Put(buf[:])
continue continue
} }
util.SafeWrite(tunOutbound, &Packet{data: buf[:], length: n}, func(v *Packet) { util.SafeWrite(tunOutbound, &Packet{data: buf[:], length: n}, func(v *Packet) {
@@ -174,9 +159,9 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
func (d *ClientDevice) writeToTun(ctx context.Context) { func (d *ClientDevice) writeToTun(ctx context.Context) {
defer util.HandleCrash() defer util.HandleCrash()
for e := range d.tunOutbound { for packet := range d.tunOutbound {
_, err := d.tun.Write(e.data[:e.length]) _, err := d.tun.Write(packet.data[:packet.length])
config.LPool.Put(e.data[:]) config.LPool.Put(packet.data[:])
if err != nil { if err != nil {
plog.G(ctx).Errorf("Failed to write packet to tun device: %v", err) plog.G(ctx).Errorf("Failed to write packet to tun device: %v", err)
util.SafeWrite(d.errChan, err) util.SafeWrite(d.errChan, err)

View File

@@ -3,8 +3,6 @@ package core
import ( import (
"encoding/binary" "encoding/binary"
"io" "io"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
) )
type DatagramPacket struct { type DatagramPacket struct {
@@ -12,9 +10,9 @@ type DatagramPacket struct {
Data []byte // []byte Data []byte // []byte
} }
func newDatagramPacket(data []byte) (r *DatagramPacket) { func newDatagramPacket(data []byte, length int) (r *DatagramPacket) {
return &DatagramPacket{ return &DatagramPacket{
DataLength: uint16(len(data)), DataLength: uint16(length),
Data: data, Data: data,
} }
} }
@@ -34,10 +32,8 @@ func readDatagramPacket(r io.Reader, b []byte) (*DatagramPacket, error) {
} }
func (d *DatagramPacket) Write(w io.Writer) error { func (d *DatagramPacket) Write(w io.Writer) error {
buf := config.LPool.Get().([]byte)[:] n := copy(d.Data[2:], d.Data[:d.DataLength])
defer config.LPool.Put(buf[:]) binary.BigEndian.PutUint16(d.Data[:2], d.DataLength)
binary.BigEndian.PutUint16(buf[:2], d.DataLength) _, err := w.Write(d.Data[:n+2])
n := copy(buf[2:], d.Data[:d.DataLength])
_, err := w.Write(buf[:n+2])
return err return err
} }

View File

@@ -262,7 +262,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err) plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err)
} }
pkgutil.SafeWrite(chDone, true) chDone <- true
}() }()
// start local -> remote data transfer // start local -> remote data transfer
@@ -273,7 +273,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err) plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err)
} }
pkgutil.SafeWrite(chDone, true) chDone <- true
}() }()
select { select {