diff --git a/pkg/core/tunendpoint.go b/pkg/core/tunendpoint.go index 10424fab..f5e80bd0 100755 --- a/pkg/core/tunendpoint.go +++ b/pkg/core/tunendpoint.go @@ -124,6 +124,11 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config log.Errorf("[TUN] Error: tun device closed") return } + // if context is done + if ctx.Err() != nil { + log.Errorf("[TUN]: write to tun error: %v, context is done: %v", err, ctx.Err()) + return + } log.Errorf("[TUN] Error: failed to write data to tun device: %v", err) continue } diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 5e3b8e19..9c69a2a7 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -215,8 +215,14 @@ func (d *Device) writeToTun() { } } -func (d *Device) parseIPHeader() { +func (d *Device) parseIPHeader(ctx context.Context) { for e := range d.tunInboundRaw { + select { + case <-ctx.Done(): + return + default: + } + if util.IsIPv4(e.data[:e.length]) { // ipv4.ParseHeader b := e.data[:e.length] @@ -240,7 +246,7 @@ func (d *Device) Close() { d.tun.Close() } -func heartbeats(tun net.Conn, in chan<- *DataElem) { +func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) { conn, err := util.GetTunDeviceByConn(tun) if err != nil { log.Errorf("get tun device error: %s", err.Error()) @@ -264,6 +270,12 @@ func heartbeats(tun net.Conn, in chan<- *DataElem) { defer ticker.Stop() for ; true; <-ticker.C { + select { + case <-ctx.Done(): + return + default: + } + for i := 0; i < 4; i++ { if bytes == nil { bytes, err = genICMPPacket(srcIPv4, config.RouterIP) @@ -352,10 +364,10 @@ func genICMPPacketIPv6(src net.IP, dst net.IP) ([]byte, error) { func (d *Device) Start(ctx context.Context) { go d.readFromTun() - go d.parseIPHeader() + go d.parseIPHeader(ctx) go d.tunInboundHandler(d.tunInbound, d.tunOutbound) go d.writeToTun() - go heartbeats(d.tun, d.tunInbound) + go heartbeats(ctx, d.tun, d.tunInbound) select { case err := <-d.chExit: @@ -381,7 +393,7 @@ func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) { chExit: h.chExit, } device.SetTunInboundHandler(func(tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem) { - for { + for ctx.Err() == nil { packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", h.node.Addr) if err != nil { log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err) diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index e9d33a64..80c3d3b8 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -127,7 +127,7 @@ type ClientDevice struct { func (d *ClientDevice) Start(ctx context.Context) { go d.tunInboundHandler(d.tunInbound, d.tunOutbound) - go heartbeats(d.tun, d.tunInbound) + go heartbeats(ctx, d.tun, d.tunInbound) select { case err := <-d.chExit: diff --git a/pkg/daemon/action/connect-fork.go b/pkg/daemon/action/connect-fork.go index 9d546252..ed9e3965 100644 --- a/pkg/daemon/action/connect-fork.go +++ b/pkg/daemon/action/connect-fork.go @@ -15,7 +15,7 @@ import ( "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) -func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectForkServer) error { +func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectForkServer) (err error) { defer func() { log.SetOutput(svr.LogFile) log.SetLevel(log.DebugLevel) @@ -44,7 +44,7 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF go util.StartupPProf(config.PProfPort) defaultlog.Default().SetOutput(io.Discard) if transferImage { - err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out) + err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out) if err != nil { return err } @@ -64,6 +64,12 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF sshCancel() return nil }) + defer func() { + if err != nil { + connect.Cleanup() + } + }() + var path string path, err = util.SshJump(sshCtx, sshConf, flags, false) if err != nil { @@ -86,11 +92,10 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF err = connect.DoConnect(sshCtx, true) if err != nil { log.Errorf("do connect error: %v", err) - connect.Cleanup() return err } - svr.secondaryConnect = append(svr.secondaryConnect, connect) + svr.secondaryConnect = append(svr.secondaryConnect, connect) return nil } diff --git a/pkg/daemon/action/get.go b/pkg/daemon/action/get.go index ff266033..6ea3fabc 100644 --- a/pkg/daemon/action/get.go +++ b/pkg/daemon/action/get.go @@ -18,7 +18,7 @@ import ( ) func (svr *Server) Get(ctx context.Context, req *rpc.GetRequest) (*rpc.GetResponse, error) { - if svr.connect == nil { + if svr.connect == nil || svr.connect.Context() == nil { return nil, errors.New("not connected") } if svr.gr == nil { diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index 2afbaa0b..11dd5321 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -85,9 +85,6 @@ type ConnectOptions struct { } func (c *ConnectOptions) Context() context.Context { - if c.ctx == nil { - c.ctx, c.cancel = context.WithCancel(context.Background()) - } return c.ctx }