hotfix: close chan (#245)

This commit is contained in:
naison
2024-05-13 19:58:56 +08:00
committed by GitHub
parent e7f00f5899
commit 3e51bf0f4d
15 changed files with 163 additions and 46 deletions

View File

@@ -3,13 +3,17 @@ package cmds
import (
"crypto/rand"
"encoding/base64"
"errors"
"net/http"
"github.com/spf13/cobra"
cmdutil "k8s.io/kubectl/pkg/cmd/util"
"k8s.io/kubectl/pkg/util/i18n"
"k8s.io/kubectl/pkg/util/templates"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon"
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
@@ -24,10 +28,21 @@ func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
return err
}
opt.ID = base64.URLEncoding.EncodeToString(b)
if opt.IsSudo {
go util.StartupPProf(config.SudoPProfPort)
} else {
go util.StartupPProf(config.PProfPort)
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(cmd *cobra.Command, args []string) (err error) {
defer opt.Stop()
defer func() {
if errors.Is(err, http.ErrServerClosed) {
err = nil
}
}()
return opt.Start(cmd.Context())
},
Hidden: true,

View File

@@ -83,7 +83,8 @@ const (
ManageBy = konfig.ManagedbyLabelKey
// pprof port
PProfPort = 32345
PProfPort = 32345
SudoPProfPort = 33345
// startup by KubeVPN
EnvStartSudoKubeVPNByKubeVPN = "DEPTH_SIGNED_BY_NAISON"

View File

@@ -104,7 +104,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) {
} else {
log.Debugf("[tcpserver] new routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
}
h.ch <- dgram
util.SafeWrite(h.ch, dgram)
}
}

View File

@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config.Engine, in chan<- *DataElem, out chan *DataElem) stack.LinkEndpoint {
@@ -37,7 +38,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
i := config.LPool.Get().([]byte)[:]
n := copy(i, bb)
bb = nil
out <- NewDataElem(i[:], n, nil, nil)
util.SafeWrite(out, NewDataElem(i[:], n, nil, nil))
}
}
}()
@@ -49,7 +50,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
read, err := tun.Read(bytes[:])
if err != nil {
if errors.Is(err, os.ErrClosed) {
log.Errorf("[TUN] Error: tun device closed")
return
}
// if context is done
@@ -111,7 +111,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
} else {
log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
in <- NewDataElem(bytes[:], read, src, dst)
util.SafeWrite(in, NewDataElem(bytes[:], read, src, dst))
}
}
}()
@@ -121,7 +121,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
config.LPool.Put(elem.Data()[:])
if err != nil {
if errors.Is(err, os.ErrClosed) {
log.Errorf("[TUN] Error: tun device closed")
return
}
// if context is done

View File

@@ -193,10 +193,10 @@ func (d *Device) readFromTun() {
return
}
if n != 0 {
d.tunInboundRaw <- &DataElem{
util.SafeWrite(d.tunInboundRaw, &DataElem{
data: b[:],
length: n,
}
})
}
}
}
@@ -238,12 +238,16 @@ func (d *Device) parseIPHeader(ctx context.Context) {
}
log.Debugf("[tun] %s --> %s, length: %d", e.src, e.dst, e.length)
d.tunInbound <- e
util.SafeWrite(d.tunInbound, e)
}
}
func (d *Device) Close() {
d.tun.Close()
util.SafeClose(d.tunInbound)
util.SafeClose(d.tunOutbound)
util.SafeClose(d.tunInboundRaw)
util.SafeClose(Chan)
}
func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
@@ -300,12 +304,12 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
} else {
src, dst = srcIPv6, config.RouterIP6
}
in <- &DataElem{
util.SafeWrite(in, &DataElem{
data: data[:],
length: length,
src: src,
dst: dst,
}
})
}
time.Sleep(time.Second)
}

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
@@ -24,7 +25,9 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
engine := h.node.Get(config.ConfigKubeVPNTransportEngine)
endpoint := NewTunEndpoint(ctx, tun, uint32(config.DefaultMTU), config.Engine(engine), in, out)
stack := NewStack(ctx, endpoint)
go stack.Wait()
defer stack.Destroy()
defer util.SafeClose(in)
defer util.SafeClose(out)
d := &ClientDevice{
tun: tun,
@@ -84,7 +87,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut
go func() {
for e := range tunInbound {
if e.src.Equal(e.dst) {
tunOutbound <- e
util.SafeWrite(tunOutbound, e)
continue
}
_, err := packetConn.WriteTo(e.data[:e.length], remoteAddr)
@@ -104,7 +107,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut
errChan <- errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr))
return
}
tunOutbound <- &DataElem{data: b[:], length: n}
util.SafeWrite(tunOutbound, &DataElem{data: b[:], length: n})
}
}()

View File

@@ -41,7 +41,6 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
var sshConf = util.ParseSshFromRPC(req.SshJump)
var transferImage = req.TransferImage
go util.StartupPProf(config.PProfPort)
defaultlog.Default().SetOutput(io.Discard)
if transferImage {
err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)

View File

@@ -61,7 +61,6 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
var sshConf = util.ParseSshFromRPC(req.SshJump)
var transferImage = req.TransferImage
go util.StartupPProf(config.PProfPort)
defaultlog.Default().SetOutput(io.Discard)
if transferImage {
err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)

View File

@@ -48,7 +48,7 @@ func (w *wsHandler) handle(ctx context.Context) {
ctx, f := context.WithCancel(ctx)
defer f()
cli, err := util.DialSshRemote(w.sshConfig)
cli, err := util.DialSshRemote(ctx, w.sshConfig)
if err != nil {
w.Log("Dial ssh remote error: %v", err)
return

View File

@@ -69,6 +69,12 @@ func (c *ConnectOptions) Cleanup() {
log.Errorf("can not update ref-count: %v", err)
}
}
// leave proxy resources
err := c.LeaveProxyResources(ctx)
if err != nil {
log.Errorf("leave proxy resources error: %v", err)
}
for _, function := range c.getRolloutFunc() {
if function != nil {
if err := function(); err != nil {
@@ -76,11 +82,6 @@ func (c *ConnectOptions) Cleanup() {
}
}
}
// leave proxy resources
err := c.LeaveProxyResources(ctx)
if err != nil {
log.Errorf("leave proxy resources error: %v", err)
}
if c.cancel != nil {
c.cancel()
}

View File

@@ -620,6 +620,10 @@ func Run(ctx context.Context, servers []core.Server) error {
errChan <- func() error {
svr := servers[i]
defer svr.Listener.Close()
go func() {
<-ctx.Done()
svr.Listener.Close()
}()
for ctx.Err() == nil {
conn, err := svr.Listener.Accept()
if err != nil {

29
pkg/util/chan.go Normal file
View File

@@ -0,0 +1,29 @@
package util
func SafeRead[T any](c chan T) (T, bool) {
defer func() {
if r := recover(); r != nil {
}
}()
tt, ok := <-c
return tt, ok
}
func SafeWrite[T any](c chan<- T, value T) {
defer func() {
if r := recover(); r != nil {
}
}()
select {
case c <- value:
default:
}
}
func SafeClose[T any](c chan T) {
defer func() {
if r := recover(); r != nil {
}
}()
close(c)
}

23
pkg/util/chan_test.go Normal file
View File

@@ -0,0 +1,23 @@
package util
import (
"fmt"
"testing"
"time"
)
func TestChanClose(t *testing.T) {
c := make(chan any)
close(c)
SafeWrite(c, nil)
c = make(chan any)
go func() {
time.AfterFunc(time.Second*3, func() {
close(c)
})
}()
for a := range c {
fmt.Printf("%v", a)
}
}

View File

@@ -118,7 +118,7 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
// transfer image to remote
var sshClient *ssh.Client
sshClient, err = DialSshRemote(conf)
sshClient, err = DialSshRemote(ctx, conf)
if err != nil {
return err
}

View File

@@ -114,7 +114,7 @@ func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
}
// DialSshRemote https://github.com/golang/go/issues/21478
func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) {
func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, err error) {
defer func() {
if err != nil {
if remote != nil {
@@ -124,21 +124,24 @@ func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) {
}()
if conf.ConfigAlias != "" {
remote, err = conf.AliasRecursion()
remote, err = conf.AliasRecursion(ctx)
} else if conf.Jump != "" {
remote, err = conf.JumpRecursion()
remote, err = conf.JumpRecursion(ctx)
} else {
remote, err = conf.Dial()
remote, err = conf.Dial(ctx)
}
// ref: https://github.com/golang/go/issues/21478
if err == nil {
go func() {
ticker := time.NewTicker(time.Second * 15)
defer ticker.Stop()
defer remote.Close()
for range ticker.C {
for ctx.Err() == nil {
time.Sleep(time.Second * 15)
_, _, err := remote.SendRequest("keepalive@golang.org", true, nil)
if err == nil || err.Error() == "request failed" {
// Any response is a success.
continue
}
if err != nil {
return
}
@@ -234,7 +237,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
return ssh.PublicKeys(key), nil
}
func copyStream(local net.Conn, remote net.Conn) {
func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
chDone := make(chan bool, 2)
// start remote -> local data transfer
@@ -265,10 +268,15 @@ func copyStream(local net.Conn, remote net.Conn) {
}
}()
<-chDone
select {
case <-chDone:
return
case <-ctx.Done():
return
}
}
func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client, err error) {
var name = config.ConfigAlias
var jumper = "ProxyJump"
var bastionList = []SshConfig{GetBastion(name, config)}
@@ -283,12 +291,12 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
}
for i := len(bastionList) - 1; i >= 0; i-- {
if client == nil {
client, err = bastionList[i].Dial()
client, err = bastionList[i].Dial(ctx)
if err != nil {
return
}
} else {
client, err = JumpTo(client, bastionList[i])
client, err = JumpTo(ctx, client, bastionList[i])
if err != nil {
return
}
@@ -297,7 +305,7 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
return
}
func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client, err error) {
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
var sshConf = &SshConfig{}
AddSshFlags(flags, sshConf)
@@ -306,7 +314,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
return nil, err
}
var baseClient *ssh.Client
baseClient, err = DialSshRemote(sshConf)
baseClient, err = DialSshRemote(ctx, sshConf)
if err != nil {
return nil, err
}
@@ -331,7 +339,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
}
for _, sshConfig := range bastionList {
client, err = JumpTo(baseClient, sshConfig)
client, err = JumpTo(ctx, baseClient, sshConfig)
if err != nil {
return
}
@@ -374,7 +382,7 @@ func GetBastion(name string, defaultValue SshConfig) SshConfig {
return config
}
func (config SshConfig) Dial() (*ssh.Client, error) {
func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error) {
if strings.Index(config.Addr, ":") < 0 {
// use default ssh port 22
config.Addr = net.JoinHostPort(config.Addr, "22")
@@ -384,16 +392,31 @@ func (config SshConfig) Dial() (*ssh.Client, error) {
if err != nil {
return nil, err
}
return ssh.Dial("tcp", config.Addr, &ssh.ClientConfig{
conn, err := net.DialTimeout("tcp", config.Addr, time.Second*10)
if err != nil {
return nil, err
}
go func() {
<-ctx.Done()
conn.Close()
if client != nil {
client.Close()
}
}()
c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{
User: config.User,
Auth: authMethod,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
BannerCallback: ssh.BannerDisplayStderr(),
Timeout: time.Second * 10,
})
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
}
func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
if strings.Index(to.Addr, ":") < 0 {
// use default ssh port 22
to.Addr = net.JoinHostPort(to.Addr, "22")
@@ -410,6 +433,14 @@ func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
if err != nil {
return
}
go func() {
<-ctx.Done()
conn.Close()
if client != nil {
client.Close()
}
bClient.Close()
}()
defer func() {
if err != nil {
if client != nil {
@@ -495,7 +526,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
}
sshClient.Close()
}
sshClient, err = DialSshRemote(conf)
sshClient, err = DialSshRemote(ctx, conf)
if err != nil {
log.Errorf("failed to dial remote ssh server: %v", err)
return nil, err
@@ -505,11 +536,20 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
go func() {
defer localListen.Close()
go func() {
<-ctx.Done()
localListen.Close()
if sshClient != nil {
sshClient.Close()
}
}()
for ctx.Err() == nil {
localConn, err := localListen.Accept()
if err != nil {
log.Errorf("failed to accept conn: %v", err)
if !errors.Is(err, net.ErrClosed) {
log.Errorf("failed to accept conn: %v", err)
}
return
}
go func() {
@@ -521,7 +561,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
return
}
defer remoteConn.Close()
copyStream(localConn, remoteConn)
copyStream(ctx, localConn, remoteConn)
}()
}
}()
@@ -551,7 +591,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
// pre-check network ip connect
var cli *ssh.Client
cli, err = DialSshRemote(conf)
cli, err = DialSshRemote(ctx, conf)
if err != nil {
return
}