diff --git a/pkg/core/bufferedtcp.go b/pkg/core/bufferedtcp.go index b4fe0ce2..cbcc21fd 100644 --- a/pkg/core/bufferedtcp.go +++ b/pkg/core/bufferedtcp.go @@ -34,10 +34,7 @@ func (c *bufferedTCP) Write(b []byte) (n int, err error) { buf := config.LPool.Get().([]byte)[:] n = copy(buf, b) - c.Chan <- &DatagramPacket{ - DataLength: uint16(n), - Data: buf, - } + c.Chan <- newDatagramPacket(buf, n) return n, nil } diff --git a/pkg/core/gvisortcpforwarder.go b/pkg/core/gvisortcpforwarder.go index 6b8da8f8..e109fdb7 100644 --- a/pkg/core/gvisortcpforwarder.go +++ b/pkg/core/gvisortcpforwarder.go @@ -73,15 +73,15 @@ func TCPForwarder(ctx context.Context, s *stack.Stack) func(stack.TransportEndpo func WriteProxyInfo(conn net.Conn, id stack.TransportEndpointID) error { var b bytes.Buffer - i := config.LPool.Get().([]byte)[:] - defer config.LPool.Put(i[:]) + buf := config.LPool.Get().([]byte)[:] + defer config.LPool.Put(buf[:]) // local port - binary.BigEndian.PutUint16(i, id.LocalPort) - b.Write(i) + binary.BigEndian.PutUint16(buf, id.LocalPort) + b.Write(buf) // remote port - binary.BigEndian.PutUint16(i, id.RemotePort) - b.Write(i) + binary.BigEndian.PutUint16(buf, id.RemotePort) + b.Write(buf) // local address b.WriteByte(byte(id.LocalAddress.Len())) diff --git a/pkg/core/gvisortunendpoint.go b/pkg/core/gvisortunendpoint.go index 161c8a0e..2d5e0ef8 100755 --- a/pkg/core/gvisortunendpoint.go +++ b/pkg/core/gvisortunendpoint.go @@ -25,8 +25,11 @@ func (h *gvisorTCPHandler) readFromEndpointWriteToTCPConn(ctx context.Context, c pktBuffer := endpoint.ReadContext(ctx) if pktBuffer != nil { sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pktBuffer.NetworkProtocolNumber, pktBuffer) - buf := pktBuffer.ToView().AsSlice() - _, err := tcpConn.Write(buf) + data := pktBuffer.ToView().AsSlice() + buf := config.LPool.Get().([]byte)[:] + n := copy(buf, data) + _, err := tcpConn.Write(buf[:n+2]) + config.LPool.Put(buf) if err != nil { plog.G(ctx).Errorf("[TUN-GVISOR] Failed to write data to tun device: %v", err) } @@ -111,7 +114,7 @@ func (h *gvisorTCPHandler) readFromTCPConnWriteToEndpoint(ctx context.Context, c func (h *gvisorTCPHandler) handlePacket(ctx context.Context, buf []byte, length int, src, dst net.IP, protocol string) error { if conn, ok := h.routeMapTCP.Load(dst.String()); ok { plog.G(ctx).Debugf("[TCP-GVISOR] Find TCP route SRC: %s to DST: %s -> %s", src, dst, conn.(net.Conn).RemoteAddr()) - dgram := newDatagramPacket(buf[:length]) + dgram := newDatagramPacket(buf, length) err := dgram.Write(conn.(net.Conn)) config.LPool.Put(buf[:]) if err != nil { diff --git a/pkg/core/gvisorudphandler.go b/pkg/core/gvisorudphandler.go index f26f7539..7aef9904 100644 --- a/pkg/core/gvisorudphandler.go +++ b/pkg/core/gvisorudphandler.go @@ -68,7 +68,7 @@ func (c *gvisorUDPConnOverTCP) Read(b []byte) (int, error) { } func (c *gvisorUDPConnOverTCP) Write(b []byte) (int, error) { - packet := newDatagramPacket(b) + packet := newDatagramPacket(b, len(b)-2) if err := packet.Write(c.Conn); err != nil { return 0, err } @@ -114,7 +114,7 @@ func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) { errChan <- err return } - datagram, err := readDatagramPacket(tcpConn, buf[:]) + datagram, err := readDatagramPacket(tcpConn, buf) if err != nil { plog.G(ctx).Errorf("[TUN-UDP] %s -> %s: %v", tcpConn.RemoteAddr(), udpConn.LocalAddr(), err) errChan <- err @@ -172,7 +172,7 @@ func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) { errChan <- err return } - packet := newDatagramPacket(buf[:n]) + packet := newDatagramPacket(buf, n) if err = packet.Write(tcpConn); err != nil { plog.G(ctx).Errorf("[TUN-UDP] Error: %s <- %s : %s", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err) errChan <- err diff --git a/pkg/core/tcp.go b/pkg/core/tcp.go index 236c9330..f3c8df96 100644 --- a/pkg/core/tcp.go +++ b/pkg/core/tcp.go @@ -29,7 +29,7 @@ func TCPTransporter(tlsInfo map[string][]byte) Transporter { } func (tr *tcpTransporter) Dial(ctx context.Context, addr string) (net.Conn, error) { - dialer := &net.Dialer{Timeout: config.DialTimeout} + dialer := &net.Dialer{Timeout: config.DialTimeout, KeepAlive: config.KeepAliveTime} conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err diff --git a/pkg/core/tcphandler.go b/pkg/core/tcphandler.go index f1eb1b3f..e79a2641 100644 --- a/pkg/core/tcphandler.go +++ b/pkg/core/tcphandler.go @@ -61,7 +61,7 @@ func (h *UDPOverTCPHandler) Handle(ctx context.Context, tcpConn net.Conn) { for ctx.Err() == nil { buf := config.LPool.Get().([]byte)[:] - datagram, err := readDatagramPacket(tcpConn, buf[:]) + datagram, err := readDatagramPacket(tcpConn, buf) if err != nil { plog.G(ctx).Errorf("[TCP] Failed to read from %s -> %s: %v", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err) config.LPool.Put(buf[:]) @@ -160,7 +160,7 @@ func (c *UDPConnOverTCP) ReadFrom(b []byte) (int, net.Addr, error) { } func (c *UDPConnOverTCP) WriteTo(b []byte, _ net.Addr) (int, error) { - packet := newDatagramPacket(b) + packet := newDatagramPacket(b, len(b)-2) if err := packet.Write(c.Conn); err != nil { return 0, err } diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index c4e56d62..9add0586 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -59,7 +59,7 @@ func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) { defer device.Close() go device.readFromTUN(ctx) go device.writeToTUN(ctx) - go device.transport(ctx, h.node.Addr, h.routeMapUDP, h.routeMapTCP) + go device.handlePacket(ctx, h.node.Addr, h.routeMapUDP, h.routeMapTCP) select { case err := <-device.errChan: @@ -131,7 +131,7 @@ func (d *Device) Close() { util.SafeClose(TCPPacketChan) } -func (d *Device) transport(ctx context.Context, addr string, routeMapUDP *sync.Map, routeMapTCP *sync.Map) { +func (d *Device) handlePacket(ctx context.Context, addr string, routeMapUDP *sync.Map, routeMapTCP *sync.Map) { packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", addr) if err != nil { util.SafeWrite(d.errChan, err) @@ -268,7 +268,7 @@ func (p *Peer) routeTUN(ctx context.Context) { } } else if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr()) - dgram := newDatagramPacket(packet.data[:packet.length]) + dgram := newDatagramPacket(packet.data, packet.length) err := dgram.Write(conn.(net.Conn)) config.LPool.Put(packet.data[:]) if err != nil { diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 34a658be..a745b561 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -23,7 +23,7 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr } defer device.Close() - go device.forwardPacketToRemote(ctx, remoteAddr, h.forward) + go device.handlePacket(ctx, remoteAddr, h.forward) go device.readFromTun(ctx) go device.writeToTun(ctx) go heartbeats(ctx, device.tun) @@ -43,51 +43,35 @@ type ClientDevice struct { forward *Forwarder } -func (d *ClientDevice) forwardPacketToRemote(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) { +func (d *ClientDevice) handlePacket(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) { for ctx.Err() == nil { - func() { - packetConn, err := getRemotePacketConn(ctx, forward) - if err != nil { - plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err) - time.Sleep(time.Second * 1) - return - } - err = transportTunPacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr) - if err != nil { - plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err) - } - }() + packetConn, err := getRemotePacketConn(ctx, forward) + if err != nil { + plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err) + time.Sleep(time.Second * 1) + continue + } + err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr) + if err != nil { + plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err) + } } } -func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (packetConn net.PacketConn, err error) { - defer func() { - if err != nil && packetConn != nil { - _ = packetConn.Close() - } - }() - if !forwarder.IsEmpty() { - var conn net.Conn - conn, err = forwarder.DialContext(ctx) - if err != nil { - return - } - var ok bool - if packetConn, ok = conn.(net.PacketConn); !ok { - err = errors.New("not a packet connection") - return - } +func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (net.PacketConn, error) { + conn, err := forwarder.DialContext(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to dial forwarder") + } + + if packetConn, ok := conn.(net.PacketConn); !ok { + return nil, errors.Errorf("failed to cast packet conn to PacketConn") } else { - var lc net.ListenConfig - packetConn, err = lc.ListenPacket(ctx, "udp", "") - if err != nil { - return - } + return packetConn, nil } - return } -func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error { +func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error { errChan := make(chan error, 2) defer packetConn.Close() @@ -105,7 +89,7 @@ func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tu }) continue } - _, err := packetConn.WriteTo(packet.data[:packet.length], remoteAddr) + _, err := packetConn.WriteTo(packet.data[:packet.length+2], remoteAddr) config.LPool.Put(packet.data[:]) if err != nil { util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr))) @@ -126,6 +110,7 @@ func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tu } if n == 0 { plog.G(ctx).Warnf("Packet length 0") + config.LPool.Put(buf[:]) continue } util.SafeWrite(tunOutbound, &Packet{data: buf[:], length: n}, func(v *Packet) { @@ -174,9 +159,9 @@ func (d *ClientDevice) readFromTun(ctx context.Context) { func (d *ClientDevice) writeToTun(ctx context.Context) { defer util.HandleCrash() - for e := range d.tunOutbound { - _, err := d.tun.Write(e.data[:e.length]) - config.LPool.Put(e.data[:]) + for packet := range d.tunOutbound { + _, err := d.tun.Write(packet.data[:packet.length]) + config.LPool.Put(packet.data[:]) if err != nil { plog.G(ctx).Errorf("Failed to write packet to tun device: %v", err) util.SafeWrite(d.errChan, err) diff --git a/pkg/core/udpovertcp.go b/pkg/core/udpovertcp.go index a204894a..89e02ac2 100644 --- a/pkg/core/udpovertcp.go +++ b/pkg/core/udpovertcp.go @@ -3,8 +3,6 @@ package core import ( "encoding/binary" "io" - - "github.com/wencaiwulue/kubevpn/v2/pkg/config" ) type DatagramPacket struct { @@ -12,9 +10,9 @@ type DatagramPacket struct { Data []byte // []byte } -func newDatagramPacket(data []byte) (r *DatagramPacket) { +func newDatagramPacket(data []byte, length int) (r *DatagramPacket) { return &DatagramPacket{ - DataLength: uint16(len(data)), + DataLength: uint16(length), Data: data, } } @@ -34,10 +32,8 @@ func readDatagramPacket(r io.Reader, b []byte) (*DatagramPacket, error) { } func (d *DatagramPacket) Write(w io.Writer) error { - buf := config.LPool.Get().([]byte)[:] - defer config.LPool.Put(buf[:]) - binary.BigEndian.PutUint16(buf[:2], d.DataLength) - n := copy(buf[2:], d.Data[:d.DataLength]) - _, err := w.Write(buf[:n+2]) + n := copy(d.Data[2:], d.Data[:d.DataLength]) + binary.BigEndian.PutUint16(d.Data[:2], d.DataLength) + _, err := w.Write(d.Data[:n+2]) return err } diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 8d375d3a..f7da47b2 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -262,7 +262,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err) } - pkgutil.SafeWrite(chDone, true) + chDone <- true }() // start local -> remote data transfer @@ -273,7 +273,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err) } - pkgutil.SafeWrite(chDone, true) + chDone <- true }() select {