mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-05 23:36:59 +08:00
refactor: add additional [2]byte for packet length (#554)
This commit is contained in:
@@ -23,7 +23,7 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr
|
||||
}
|
||||
|
||||
defer device.Close()
|
||||
go device.forwardPacketToRemote(ctx, remoteAddr, h.forward)
|
||||
go device.handlePacket(ctx, remoteAddr, h.forward)
|
||||
go device.readFromTun(ctx)
|
||||
go device.writeToTun(ctx)
|
||||
go heartbeats(ctx, device.tun)
|
||||
@@ -43,51 +43,35 @@ type ClientDevice struct {
|
||||
forward *Forwarder
|
||||
}
|
||||
|
||||
func (d *ClientDevice) forwardPacketToRemote(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) {
|
||||
func (d *ClientDevice) handlePacket(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) {
|
||||
for ctx.Err() == nil {
|
||||
func() {
|
||||
packetConn, err := getRemotePacketConn(ctx, forward)
|
||||
if err != nil {
|
||||
plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err)
|
||||
time.Sleep(time.Second * 1)
|
||||
return
|
||||
}
|
||||
err = transportTunPacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr)
|
||||
if err != nil {
|
||||
plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err)
|
||||
}
|
||||
}()
|
||||
packetConn, err := getRemotePacketConn(ctx, forward)
|
||||
if err != nil {
|
||||
plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err)
|
||||
time.Sleep(time.Second * 1)
|
||||
continue
|
||||
}
|
||||
err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr)
|
||||
if err != nil {
|
||||
plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (packetConn net.PacketConn, err error) {
|
||||
defer func() {
|
||||
if err != nil && packetConn != nil {
|
||||
_ = packetConn.Close()
|
||||
}
|
||||
}()
|
||||
if !forwarder.IsEmpty() {
|
||||
var conn net.Conn
|
||||
conn, err = forwarder.DialContext(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
if packetConn, ok = conn.(net.PacketConn); !ok {
|
||||
err = errors.New("not a packet connection")
|
||||
return
|
||||
}
|
||||
func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (net.PacketConn, error) {
|
||||
conn, err := forwarder.DialContext(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to dial forwarder")
|
||||
}
|
||||
|
||||
if packetConn, ok := conn.(net.PacketConn); !ok {
|
||||
return nil, errors.Errorf("failed to cast packet conn to PacketConn")
|
||||
} else {
|
||||
var lc net.ListenConfig
|
||||
packetConn, err = lc.ListenPacket(ctx, "udp", "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error {
|
||||
func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error {
|
||||
errChan := make(chan error, 2)
|
||||
defer packetConn.Close()
|
||||
|
||||
@@ -105,7 +89,7 @@ func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tu
|
||||
})
|
||||
continue
|
||||
}
|
||||
_, err := packetConn.WriteTo(packet.data[:packet.length], remoteAddr)
|
||||
_, err := packetConn.WriteTo(packet.data[:packet.length+2], remoteAddr)
|
||||
config.LPool.Put(packet.data[:])
|
||||
if err != nil {
|
||||
util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr)))
|
||||
@@ -126,6 +110,7 @@ func transportTunPacketClient(ctx context.Context, tunInbound <-chan *Packet, tu
|
||||
}
|
||||
if n == 0 {
|
||||
plog.G(ctx).Warnf("Packet length 0")
|
||||
config.LPool.Put(buf[:])
|
||||
continue
|
||||
}
|
||||
util.SafeWrite(tunOutbound, &Packet{data: buf[:], length: n}, func(v *Packet) {
|
||||
@@ -174,9 +159,9 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
|
||||
|
||||
func (d *ClientDevice) writeToTun(ctx context.Context) {
|
||||
defer util.HandleCrash()
|
||||
for e := range d.tunOutbound {
|
||||
_, err := d.tun.Write(e.data[:e.length])
|
||||
config.LPool.Put(e.data[:])
|
||||
for packet := range d.tunOutbound {
|
||||
_, err := d.tun.Write(packet.data[:packet.length])
|
||||
config.LPool.Put(packet.data[:])
|
||||
if err != nil {
|
||||
plog.G(ctx).Errorf("Failed to write packet to tun device: %v", err)
|
||||
util.SafeWrite(d.errChan, err)
|
||||
|
Reference in New Issue
Block a user