refactor: use tcp conn instead of packet conn

This commit is contained in:
fengcaiwen
2025-04-27 23:03:45 +08:00
parent a6ec321e46
commit c4540b1930
3 changed files with 32 additions and 47 deletions

View File

@@ -128,7 +128,7 @@ func (h *UDPOverTCPHandler) removeFromRouteMapTCP(ctx context.Context, tcpConn n
})
}
var _ net.PacketConn = (*UDPConnOverTCP)(nil)
var _ net.Conn = (*UDPConnOverTCP)(nil)
// UDPConnOverTCP fake udp connection over tcp connection
type UDPConnOverTCP struct {
@@ -141,20 +141,20 @@ func newUDPConnOverTCP(ctx context.Context, conn net.Conn) (net.Conn, error) {
return &UDPConnOverTCP{ctx: ctx, Conn: conn}, nil
}
func (c *UDPConnOverTCP) ReadFrom(b []byte) (int, net.Addr, error) {
func (c *UDPConnOverTCP) Read(b []byte) (int, error) {
select {
case <-c.ctx.Done():
return 0, nil, c.ctx.Err()
return 0, c.ctx.Err()
default:
datagram, err := readDatagramPacket(c.Conn, b)
if err != nil {
return 0, nil, err
return 0, err
}
return int(datagram.DataLength), nil, nil
return int(datagram.DataLength), nil
}
}
func (c *UDPConnOverTCP) WriteTo(b []byte, _ net.Addr) (int, error) {
func (c *UDPConnOverTCP) Write(b []byte) (int, error) {
buf := config.LPool.Get().([]byte)[:]
n := copy(buf, b)
defer config.LPool.Put(buf)

View File

@@ -37,12 +37,7 @@ func TunHandler(forward *Forwarder, node *Node) Handler {
func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) {
if remote := h.node.Remote; remote != "" {
remoteAddr, err := net.ResolveUDPAddr("udp", remote)
if err != nil {
plog.G(ctx).Errorf("Failed to resolve udp addr %s: %v", remote, err)
return
}
h.HandleClient(ctx, tun, remoteAddr)
h.HandleClient(ctx, tun)
} else {
h.HandleServer(ctx, tun)
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr *net.UDPAddr) {
func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
device := &ClientDevice{
tun: tun,
tunInbound: make(chan *Packet, MaxSize),
@@ -23,7 +23,7 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr
}
defer device.Close()
go device.handlePacket(ctx, remoteAddr, h.forward)
go device.handlePacket(ctx, h.forward)
go device.readFromTun(ctx)
go device.writeToTun(ctx)
go heartbeats(ctx, device.tun)
@@ -43,56 +43,40 @@ type ClientDevice struct {
forward *Forwarder
}
func (d *ClientDevice) handlePacket(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) {
func (d *ClientDevice) handlePacket(ctx context.Context, forward *Forwarder) {
for ctx.Err() == nil {
packetConn, err := getRemotePacketConn(ctx, forward)
conn, err := forwardConn(ctx, forward)
if err != nil {
plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err)
plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), forward.node.Remote, err)
time.Sleep(time.Second * 1)
continue
}
err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr)
err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, conn)
if err != nil {
plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err)
plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", conn.RemoteAddr(), err)
}
}
}
func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (net.PacketConn, error) {
func forwardConn(ctx context.Context, forwarder *Forwarder) (net.Conn, error) {
conn, err := forwarder.DialContext(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to dial forwarder")
}
if packetConn, ok := conn.(net.PacketConn); !ok {
return nil, errors.Errorf("failed to cast packet conn to PacketConn")
} else {
return packetConn, nil
}
return conn, nil
}
func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error {
func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, conn net.Conn) error {
errChan := make(chan error, 2)
defer packetConn.Close()
defer conn.Close()
go func() {
defer util.HandleCrash()
for packet := range tunInbound {
if packet.src.Equal(packet.dst) {
util.SafeWrite(tunOutbound, packet, func(v *Packet) {
var p = "unknown"
if _, _, protocol, err := util.ParseIP(v.data[:v.length]); err == nil {
p = layers.IPProtocol(protocol).String()
}
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, p, v.length)
})
continue
}
_, err := packetConn.WriteTo(packet.data[:packet.length], remoteAddr)
_, err := conn.Write(packet.data[:packet.length])
config.LPool.Put(packet.data[:])
if err != nil {
util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr)))
util.SafeWrite(errChan, errors.Wrap(err, "failed to write packet to remote"))
return
}
}
@@ -102,10 +86,10 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo
defer util.HandleCrash()
for {
buf := config.LPool.Get().([]byte)[:]
n, _, err := packetConn.ReadFrom(buf[:])
n, err := conn.Read(buf[:])
if err != nil {
config.LPool.Put(buf[:])
util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr)))
util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", conn.RemoteAddr())))
return
}
if n == 0 {
@@ -115,7 +99,7 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo
}
util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", packetConn.LocalAddr(), remoteAddr, v.length)
plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length)
})
}
}()
@@ -150,10 +134,16 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
continue
}
plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n)
util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) {
packet := NewPacket(buf[:], n, src, dst)
f := func(v *Packet) {
config.LPool.Put(v.data[:])
plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length)
})
}
if packet.src.Equal(packet.dst) {
util.SafeWrite(d.tunOutbound, packet, f)
continue
}
util.SafeWrite(d.tunInbound, packet, f)
}
}
@@ -188,7 +178,7 @@ func heartbeats(ctx context.Context, tun net.Conn) {
return
}
ticker := time.NewTicker(time.Second * 60)
ticker := time.NewTicker(config.KeepAliveTime)
defer ticker.Stop()
for ; ctx.Err() == nil; <-ticker.C {