hotfix: cleanup in time when connect lite mode (#243)

This commit is contained in:
naison
2024-05-13 10:14:54 +08:00
committed by GitHub
parent 70d5723e97
commit e7f00f5899
6 changed files with 33 additions and 14 deletions

View File

@@ -124,6 +124,11 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
log.Errorf("[TUN] Error: tun device closed") log.Errorf("[TUN] Error: tun device closed")
return return
} }
// if context is done
if ctx.Err() != nil {
log.Errorf("[TUN]: write to tun error: %v, context is done: %v", err, ctx.Err())
return
}
log.Errorf("[TUN] Error: failed to write data to tun device: %v", err) log.Errorf("[TUN] Error: failed to write data to tun device: %v", err)
continue continue
} }

View File

@@ -215,8 +215,14 @@ func (d *Device) writeToTun() {
} }
} }
func (d *Device) parseIPHeader() { func (d *Device) parseIPHeader(ctx context.Context) {
for e := range d.tunInboundRaw { for e := range d.tunInboundRaw {
select {
case <-ctx.Done():
return
default:
}
if util.IsIPv4(e.data[:e.length]) { if util.IsIPv4(e.data[:e.length]) {
// ipv4.ParseHeader // ipv4.ParseHeader
b := e.data[:e.length] b := e.data[:e.length]
@@ -240,7 +246,7 @@ func (d *Device) Close() {
d.tun.Close() d.tun.Close()
} }
func heartbeats(tun net.Conn, in chan<- *DataElem) { func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
conn, err := util.GetTunDeviceByConn(tun) conn, err := util.GetTunDeviceByConn(tun)
if err != nil { if err != nil {
log.Errorf("get tun device error: %s", err.Error()) log.Errorf("get tun device error: %s", err.Error())
@@ -264,6 +270,12 @@ func heartbeats(tun net.Conn, in chan<- *DataElem) {
defer ticker.Stop() defer ticker.Stop()
for ; true; <-ticker.C { for ; true; <-ticker.C {
select {
case <-ctx.Done():
return
default:
}
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
if bytes == nil { if bytes == nil {
bytes, err = genICMPPacket(srcIPv4, config.RouterIP) bytes, err = genICMPPacket(srcIPv4, config.RouterIP)
@@ -352,10 +364,10 @@ func genICMPPacketIPv6(src net.IP, dst net.IP) ([]byte, error) {
func (d *Device) Start(ctx context.Context) { func (d *Device) Start(ctx context.Context) {
go d.readFromTun() go d.readFromTun()
go d.parseIPHeader() go d.parseIPHeader(ctx)
go d.tunInboundHandler(d.tunInbound, d.tunOutbound) go d.tunInboundHandler(d.tunInbound, d.tunOutbound)
go d.writeToTun() go d.writeToTun()
go heartbeats(d.tun, d.tunInbound) go heartbeats(ctx, d.tun, d.tunInbound)
select { select {
case err := <-d.chExit: case err := <-d.chExit:
@@ -381,7 +393,7 @@ func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) {
chExit: h.chExit, chExit: h.chExit,
} }
device.SetTunInboundHandler(func(tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem) { device.SetTunInboundHandler(func(tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem) {
for { for ctx.Err() == nil {
packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", h.node.Addr) packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", h.node.Addr)
if err != nil { if err != nil {
log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err) log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err)

View File

@@ -127,7 +127,7 @@ type ClientDevice struct {
func (d *ClientDevice) Start(ctx context.Context) { func (d *ClientDevice) Start(ctx context.Context) {
go d.tunInboundHandler(d.tunInbound, d.tunOutbound) go d.tunInboundHandler(d.tunInbound, d.tunOutbound)
go heartbeats(d.tun, d.tunInbound) go heartbeats(ctx, d.tun, d.tunInbound)
select { select {
case err := <-d.chExit: case err := <-d.chExit:

View File

@@ -15,7 +15,7 @@ import (
"github.com/wencaiwulue/kubevpn/v2/pkg/util" "github.com/wencaiwulue/kubevpn/v2/pkg/util"
) )
func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectForkServer) error { func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectForkServer) (err error) {
defer func() { defer func() {
log.SetOutput(svr.LogFile) log.SetOutput(svr.LogFile)
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
@@ -44,7 +44,7 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
go util.StartupPProf(config.PProfPort) go util.StartupPProf(config.PProfPort)
defaultlog.Default().SetOutput(io.Discard) defaultlog.Default().SetOutput(io.Discard)
if transferImage { if transferImage {
err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out) err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
if err != nil { if err != nil {
return err return err
} }
@@ -64,6 +64,12 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
sshCancel() sshCancel()
return nil return nil
}) })
defer func() {
if err != nil {
connect.Cleanup()
}
}()
var path string var path string
path, err = util.SshJump(sshCtx, sshConf, flags, false) path, err = util.SshJump(sshCtx, sshConf, flags, false)
if err != nil { if err != nil {
@@ -86,11 +92,10 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
err = connect.DoConnect(sshCtx, true) err = connect.DoConnect(sshCtx, true)
if err != nil { if err != nil {
log.Errorf("do connect error: %v", err) log.Errorf("do connect error: %v", err)
connect.Cleanup()
return err return err
} }
svr.secondaryConnect = append(svr.secondaryConnect, connect)
svr.secondaryConnect = append(svr.secondaryConnect, connect)
return nil return nil
} }

View File

@@ -18,7 +18,7 @@ import (
) )
func (svr *Server) Get(ctx context.Context, req *rpc.GetRequest) (*rpc.GetResponse, error) { func (svr *Server) Get(ctx context.Context, req *rpc.GetRequest) (*rpc.GetResponse, error) {
if svr.connect == nil { if svr.connect == nil || svr.connect.Context() == nil {
return nil, errors.New("not connected") return nil, errors.New("not connected")
} }
if svr.gr == nil { if svr.gr == nil {

View File

@@ -85,9 +85,6 @@ type ConnectOptions struct {
} }
func (c *ConnectOptions) Context() context.Context { func (c *ConnectOptions) Context() context.Context {
if c.ctx == nil {
c.ctx, c.cancel = context.WithCancel(context.Background())
}
return c.ctx return c.ctx
} }