refactor: refactor code (#687)

This commit is contained in:
naison
2025-08-06 18:38:40 +08:00
committed by GitHub
parent a13540a258
commit 6197138ad6
14 changed files with 51 additions and 118 deletions

View File

@@ -17,8 +17,9 @@ import (
func CmdConnection(f cmdutil.Factory) *cobra.Command {
cmd := &cobra.Command{
Use: "connection",
Short: "Connection management",
Use: "connection",
Short: "Connection management",
Aliases: []string{"conn"},
}
cmd.AddCommand(cmdConnectionList(f))
cmd.AddCommand(cmdConnectionUse(f))
@@ -34,6 +35,8 @@ func cmdConnectionList(f cmdutil.Factory) *cobra.Command {
Example: templates.Examples(i18n.T(`
# list all connections
kubevpn connection ls
# list connections by alias conn
kubevpn conn ls
`)),
PreRunE: func(cmd *cobra.Command, args []string) (err error) {
// startup daemon process and sudo process
@@ -68,7 +71,7 @@ func cmdConnectionUse(f cmdutil.Factory) *cobra.Command {
Use a specific connection.
`)),
Example: templates.Examples(i18n.T(`
# use a specific connection, change current connection to special id, leave or unsync will use this connection
# use a specific connection, change current connection to special id, cmd sync/unsync will use this connection
kubevpn connection use 03dc50feb8c3
`)),
PreRunE: func(cmd *cobra.Command, args []string) (err error) {

View File

@@ -192,9 +192,9 @@ var (
ipv4: 20
ipv6: 40
mtu: 1417
mtu = 1500 - ip header(20/40 v4/v6) - tcp header (20) - tls1.3(5+1+16) - packet over tcp(length(2)+remark(1)) = 1415
*/
DefaultMTU = 1500 - max(20, 40) - 20 - 5 - 2 - 16
DefaultMTU = 1500 - max(20, 40) - 20 - (5 + 1 + 16) - (2 + 1)
)
var (

View File

@@ -6,22 +6,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
func ICMPForwarder(ctx context.Context, s *stack.Stack) func(stack.TransportEndpointID, *stack.PacketBuffer) bool {
return func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
return func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
defer pkt.DecRef()
plog.G(ctx).Infof("[TUN-ICMP] LocalPort: %d, LocalAddress: %s, RemotePort: %d, RemoteAddress %s",
id.LocalPort, id.LocalAddress.String(), id.RemotePort, id.RemoteAddress.String(),
)
ctx1, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
ok, err := util.PingOnce(ctx1, id.RemoteAddress.String(), id.LocalAddress.String())
if err != nil {
plog.G(ctx).Errorf("[TUN-ICMP] Failed to ping dst %s from src %s",
id.LocalAddress.String(), id.RemoteAddress.String(),
)
}
return ok
return true
}
}

View File

@@ -46,18 +46,15 @@ func LocalTCPForwarder(ctx context.Context, s *stack.Stack) func(stack.Transport
// 2, dial proxy
var host string
var network string
if id.LocalAddress.To4() != (tcpip.Address{}) {
host = "127.0.0.1"
network = "tcp4"
} else {
host = net.IPv6loopback.String()
network = "tcp6"
}
port := fmt.Sprintf("%d", id.LocalPort)
var d = net.Dialer{Timeout: time.Second * 5}
var remote net.Conn
remote, err = d.DialContext(ctx, network, net.JoinHostPort(host, port))
remote, err = d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
if err != nil {
plog.G(ctx).Errorf("[TUN-TCP] Failed to connect addr %s: %v", net.JoinHostPort(host, port), err)
return

View File

@@ -46,8 +46,8 @@ func (h *gvisorLocalHandler) Run(ctx context.Context) {
readFromEndpointWriteToTun(ctx, endpoint, h.outbound)
util.SafeClose(h.errChan)
}()
stack := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] "))
defer stack.Destroy()
s := NewLocalStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] "))
defer s.Destroy()
select {
case <-h.errChan:
return

View File

@@ -2,11 +2,7 @@ package core
import (
"context"
"net"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -21,10 +17,10 @@ import (
func readFromEndpointWriteToTun(ctx context.Context, endpoint *channel.Endpoint, out chan<- *Packet) {
for ctx.Err() == nil {
pktBuffer := endpoint.ReadContext(ctx)
if pktBuffer != nil {
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pktBuffer.NetworkProtocolNumber, pktBuffer)
data := pktBuffer.ToView().AsSlice()
pkt := endpoint.ReadContext(ctx)
if pkt != nil {
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt)
data := pkt.ToView().AsSlice()
buf := config.LPool.Get().([]byte)[:]
n := copy(buf[1:], data)
buf[0] = 0
@@ -47,39 +43,18 @@ func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet
// Try to determine network protocol number, default zero.
var protocol tcpip.NetworkProtocolNumber
var ipProtocol int
var src, dst net.IP
// TUN interface with IFF_NO_PI enabled, thus
// we need to determine protocol from version field
if util.IsIPv4(packet.data[1:packet.length]) {
protocol = header.IPv4ProtocolNumber
ipHeader, err := ipv4.ParseHeader(packet.data[1:packet.length])
if err != nil {
plog.G(ctx).Errorf("Failed to parse IPv4 header: %v", err)
config.LPool.Put(packet.data[:])
continue
}
ipProtocol = ipHeader.Protocol
src = ipHeader.Src
dst = ipHeader.Dst
} else if util.IsIPv6(packet.data[1:packet.length]) {
protocol = header.IPv6ProtocolNumber
ipHeader, err := ipv6.ParseHeader(packet.data[1:packet.length])
if err != nil {
plog.G(ctx).Errorf("[TCP-GVISOR] Failed to parse IPv6 header: %s", err.Error())
config.LPool.Put(packet.data[:])
continue
}
ipProtocol = ipHeader.NextHeader
src = ipHeader.Src
dst = ipHeader.Dst
} else {
plog.G(ctx).Errorf("[TCP-GVISOR] Unknown packet")
config.LPool.Put(packet.data[:])
continue
}
ipProto := layers.IPProtocol(ipProtocol)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: 0,
Payload: buffer.MakeWithData(packet.data[1:packet.length]),
@@ -88,6 +63,5 @@ func readFromGvisorInboundWriteToEndpoint(ctx context.Context, in <-chan *Packet
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionRecv, protocol, pkt)
endpoint.InjectInbound(protocol, pkt)
pkt.DecRef()
plog.G(ctx).Debugf("[TCP-GVISOR] Write to gvisor. SRC: %s, DST: %s, IPProtocol: %s, Protocol: %v, Length: %d", src, dst, ipProto.String(), protocol, packet.length)
}
}

View File

@@ -29,13 +29,10 @@ func LocalUDPForwarder(ctx context.Context, s *stack.Stack) func(id stack.Transp
Port: int(id.RemotePort),
}
var ip net.IP
var network string
if id.LocalAddress.To4() != (tcpip.Address{}) {
ip = net.ParseIP("127.0.0.1")
network = "udp4"
} else {
ip = net.IPv6loopback
network = "udp6"
}
dst := &net.UDPAddr{
IP: ip,
@@ -50,7 +47,7 @@ func LocalUDPForwarder(ctx context.Context, s *stack.Stack) func(id stack.Transp
}
// dial dst
remote, err1 := net.DialUDP(network, nil, dst)
remote, err1 := net.DialUDP("udp", nil, dst)
if err1 != nil {
plog.G(ctx).Errorf("[TUN-UDP] Failed to connect dst: %s: %v", dst.String(), err1)
return

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/pkg/errors"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -43,16 +42,10 @@ func TCPForwarder(ctx context.Context, s *stack.Stack) func(stack.TransportEndpo
}()
// 2, dial proxy
host := id.LocalAddress.String()
var network string
if id.LocalAddress.To4() != (tcpip.Address{}) {
network = "tcp4"
} else {
network = "tcp6"
}
port := fmt.Sprintf("%d", id.LocalPort)
var remote net.Conn
var d = net.Dialer{Timeout: time.Second * 5}
remote, err = d.DialContext(ctx, network, net.JoinHostPort(host, port))
remote, err = d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
if err != nil {
plog.G(ctx).Errorf("[TUN-TCP] Failed to connect addr %s: %v", net.JoinHostPort(host, port), err)
return

View File

@@ -53,8 +53,8 @@ func (h *gvisorTCPHandler) handle(ctx context.Context, tcpConn net.Conn) {
h.readFromEndpointWriteToTCPConn(ctx, tcpConn, endpoint)
util.SafeClose(errChan)
}()
stack := NewStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] "))
defer stack.Destroy()
s := NewStack(ctx, sniffer.NewWithPrefix(endpoint, "[gVISOR] "))
defer s.Destroy()
select {
case <-errChan:
return

View File

@@ -22,10 +22,10 @@ import (
func (h *gvisorTCPHandler) readFromEndpointWriteToTCPConn(ctx context.Context, conn net.Conn, endpoint *channel.Endpoint) {
tcpConn, _ := newGvisorUDPConnOverTCP(ctx, conn)
for ctx.Err() == nil {
pktBuffer := endpoint.ReadContext(ctx)
if pktBuffer != nil {
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pktBuffer.NetworkProtocolNumber, pktBuffer)
data := pktBuffer.ToView().AsSlice()
pkt := endpoint.ReadContext(ctx)
if pkt != nil {
sniffer.LogPacket("[gVISOR] ", sniffer.DirectionSend, pkt.NetworkProtocolNumber, pkt)
data := pkt.ToView().AsSlice()
buf := config.LPool.Get().([]byte)[:]
n := copy(buf[1:], data)
buf[0] = 0

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/pkg/errors"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -40,15 +39,8 @@ func UDPForwarder(ctx context.Context, s *stack.Stack) func(id stack.TransportEn
return
}
var network string
if id.LocalAddress.To4() != (tcpip.Address{}) {
network = "udp4"
} else {
network = "udp6"
}
// dial dst
remote, err1 := net.DialUDP(network, nil, dst)
remote, err1 := net.DialUDP("udp", nil, dst)
if err1 != nil {
plog.G(ctx).Errorf("[TUN-UDP] Failed to connect dst: %s: %v", dst.String(), err1)
return

View File

@@ -67,6 +67,7 @@ func (svr *Server) Disconnect(resp rpc.Daemon_DisconnectServer) (err error) {
}
}
svr.connections = nil
svr.currentConnectionID = ""
case req.GetConnectionID() != "":
var connects = *new(handler.Connects)
for i := 0; i < len(svr.connections); i++ {
@@ -84,6 +85,11 @@ func (svr *Server) Disconnect(resp rpc.Daemon_DisconnectServer) (err error) {
connect.Cleanup(ctx)
}
}
if svr.currentConnectionID == req.GetConnectionID() {
for _, connection := range svr.connections {
svr.currentConnectionID = connection.GetConnectionID()
}
}
case req.KubeconfigBytes != nil && req.Namespace != nil:
err = disconnectByKubeconfig(
resp.Context(),
@@ -123,19 +129,24 @@ func disconnectByKubeconfig(ctx context.Context, svr *Server, kubeconfigBytes st
if err != nil {
return err
}
disconnect(ctx, svr, connect)
connectionID, err := util.GetConnectionID(ctx, connect.GetClientset().CoreV1().Namespaces(), connect.Namespace)
if err != nil {
return err
}
disconnect(ctx, svr, connectionID)
if svr.currentConnectionID == connectionID {
for _, connection := range svr.connections {
svr.currentConnectionID = connection.GetConnectionID()
}
}
return nil
}
func disconnect(ctx context.Context, svr *Server, connect *handler.ConnectOptions) {
func disconnect(ctx context.Context, svr *Server, connectionID string) {
for i := 0; i < len(svr.connections); i++ {
options := svr.connections[i]
isSameCluster, _ := util.IsSameConnection(
ctx,
options.GetClientset().CoreV1(), options.OriginNamespace,
connect.GetClientset().CoreV1(), connect.Namespace,
)
if isSameCluster {
id, _ := util.GetConnectionID(ctx, options.GetClientset().CoreV1().Namespaces(), options.OriginNamespace)
if id == connectionID {
plog.G(ctx).Infof("Disconnecting from the cluster...")
options.Cleanup(ctx)
svr.connections = append(svr.connections[:i], svr.connections[i+1:]...)

View File

@@ -63,7 +63,7 @@ func (w *wsHandler) handle(c context.Context, lite bool) {
defer cli.Close()
if !lite {
err = w.createTwoWayTUNTunnel(ctx, cli)
err = w.createTunnel(ctx, cli)
if err != nil {
return
}
@@ -80,7 +80,7 @@ func (w *wsHandler) handle(c context.Context, lite bool) {
return
}
func (w *wsHandler) createTwoWayTUNTunnel(ctx context.Context, cli *ssh.Client) error {
func (w *wsHandler) createTunnel(ctx context.Context, cli *ssh.Client) error {
err := w.installKubevpnOnRemote(ctx, cli)
if err != nil {
//w.Log("Install kubevpn error: %v", err)
@@ -126,8 +126,8 @@ func (w *wsHandler) createTwoWayTUNTunnel(ctx context.Context, cli *ssh.Client)
w.Log("Failed to parse server IP %s, stderr: %s: %v", string(serverIP), string(stderr), err)
return err
}
msg := fmt.Sprintf("| You can use client: %s to communicate with server: %s |", clientIP.IP.String(), ip.String())
w.PrintLine(msg)
msg := util.PrintStr(fmt.Sprintf("You can use client: %s to communicate with server: %s", clientIP.IP.String(), ip.String()))
w.Log(msg)
w.cidr = append(w.cidr, string(serverIP))
r := core.Route{
Listeners: []string{
@@ -149,7 +149,7 @@ func (w *wsHandler) createTwoWayTUNTunnel(ctx context.Context, cli *ssh.Client)
plog.G(ctx).Info("Connected private safe tunnel")
go func() {
for ctx.Err() == nil {
util.PingOnce(ctx, clientIP.IP.String(), ip.String())
_, _ = util.Ping(ctx, clientIP.IP.String(), ip.String())
time.Sleep(time.Second * 15)
}
}()
@@ -346,13 +346,6 @@ func (w *wsHandler) Log(format string, a ...any) {
plog.G(context.Background()).Infof(format, a...)
}
func (w *wsHandler) PrintLine(msg string) {
line := "+" + strings.Repeat("-", len(msg)-2) + "+"
w.Log(line)
w.Log(msg)
w.Log(line)
}
var SessionMap = make(map[string]*ssh.Session)
var CondReady = make(map[string]context.Context)

View File

@@ -83,25 +83,6 @@ func GetTunDeviceIP(tunName string) (net.IP, net.IP, net.IP, error) {
return srcIPv4, srcIPv6, dockerSrcIPv4, nil
}
func PingOnce(ctx context.Context, srcIP, dstIP string) (bool, error) {
pinger, err := probing.NewPinger(dstIP)
if err != nil {
return false, err
}
pinger.Source = srcIP
pinger.SetLogger(nil)
pinger.SetPrivileged(true)
pinger.Count = 1
pinger.Timeout = time.Second * 1
pinger.ResolveTimeout = time.Second * 1
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
return false, err
}
stat := pinger.Statistics()
return stat.PacketsRecv == stat.PacketsSent, err
}
func Ping(ctx context.Context, srcIP, dstIP string) (bool, error) {
pinger, err := probing.NewPinger(dstIP)
if err != nil {