mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-19 05:34:38 +08:00
hotfix: close chan (#245)
This commit is contained in:
@@ -3,13 +3,17 @@ package cmds
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
cmdutil "k8s.io/kubectl/pkg/cmd/util"
|
||||
"k8s.io/kubectl/pkg/util/i18n"
|
||||
"k8s.io/kubectl/pkg/util/templates"
|
||||
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon"
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||
)
|
||||
|
||||
func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
|
||||
@@ -24,10 +28,21 @@ func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
|
||||
return err
|
||||
}
|
||||
opt.ID = base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
if opt.IsSudo {
|
||||
go util.StartupPProf(config.SudoPProfPort)
|
||||
} else {
|
||||
go util.StartupPProf(config.PProfPort)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
RunE: func(cmd *cobra.Command, args []string) (err error) {
|
||||
defer opt.Stop()
|
||||
defer func() {
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
err = nil
|
||||
}
|
||||
}()
|
||||
return opt.Start(cmd.Context())
|
||||
},
|
||||
Hidden: true,
|
||||
|
@@ -83,7 +83,8 @@ const (
|
||||
ManageBy = konfig.ManagedbyLabelKey
|
||||
|
||||
// pprof port
|
||||
PProfPort = 32345
|
||||
PProfPort = 32345
|
||||
SudoPProfPort = 33345
|
||||
|
||||
// startup by KubeVPN
|
||||
EnvStartSudoKubeVPNByKubeVPN = "DEPTH_SIGNED_BY_NAISON"
|
||||
|
@@ -104,7 +104,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) {
|
||||
} else {
|
||||
log.Debugf("[tcpserver] new routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
|
||||
}
|
||||
h.ch <- dgram
|
||||
util.SafeWrite(h.ch, dgram)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||
)
|
||||
|
||||
func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config.Engine, in chan<- *DataElem, out chan *DataElem) stack.LinkEndpoint {
|
||||
@@ -37,7 +38,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
||||
i := config.LPool.Get().([]byte)[:]
|
||||
n := copy(i, bb)
|
||||
bb = nil
|
||||
out <- NewDataElem(i[:], n, nil, nil)
|
||||
util.SafeWrite(out, NewDataElem(i[:], n, nil, nil))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -49,7 +50,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
||||
read, err := tun.Read(bytes[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) {
|
||||
log.Errorf("[TUN] Error: tun device closed")
|
||||
return
|
||||
}
|
||||
// if context is done
|
||||
@@ -111,7 +111,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
||||
log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||
} else {
|
||||
log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||
in <- NewDataElem(bytes[:], read, src, dst)
|
||||
util.SafeWrite(in, NewDataElem(bytes[:], read, src, dst))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -121,7 +121,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
||||
config.LPool.Put(elem.Data()[:])
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) {
|
||||
log.Errorf("[TUN] Error: tun device closed")
|
||||
return
|
||||
}
|
||||
// if context is done
|
||||
|
@@ -193,10 +193,10 @@ func (d *Device) readFromTun() {
|
||||
return
|
||||
}
|
||||
if n != 0 {
|
||||
d.tunInboundRaw <- &DataElem{
|
||||
util.SafeWrite(d.tunInboundRaw, &DataElem{
|
||||
data: b[:],
|
||||
length: n,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -238,12 +238,16 @@ func (d *Device) parseIPHeader(ctx context.Context) {
|
||||
}
|
||||
|
||||
log.Debugf("[tun] %s --> %s, length: %d", e.src, e.dst, e.length)
|
||||
d.tunInbound <- e
|
||||
util.SafeWrite(d.tunInbound, e)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Device) Close() {
|
||||
d.tun.Close()
|
||||
util.SafeClose(d.tunInbound)
|
||||
util.SafeClose(d.tunOutbound)
|
||||
util.SafeClose(d.tunInboundRaw)
|
||||
util.SafeClose(Chan)
|
||||
}
|
||||
|
||||
func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
|
||||
@@ -300,12 +304,12 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
|
||||
} else {
|
||||
src, dst = srcIPv6, config.RouterIP6
|
||||
}
|
||||
in <- &DataElem{
|
||||
util.SafeWrite(in, &DataElem{
|
||||
data: data[:],
|
||||
length: length,
|
||||
src: src,
|
||||
dst: dst,
|
||||
}
|
||||
})
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||
)
|
||||
|
||||
func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
|
||||
@@ -24,7 +25,9 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
|
||||
engine := h.node.Get(config.ConfigKubeVPNTransportEngine)
|
||||
endpoint := NewTunEndpoint(ctx, tun, uint32(config.DefaultMTU), config.Engine(engine), in, out)
|
||||
stack := NewStack(ctx, endpoint)
|
||||
go stack.Wait()
|
||||
defer stack.Destroy()
|
||||
defer util.SafeClose(in)
|
||||
defer util.SafeClose(out)
|
||||
|
||||
d := &ClientDevice{
|
||||
tun: tun,
|
||||
@@ -84,7 +87,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut
|
||||
go func() {
|
||||
for e := range tunInbound {
|
||||
if e.src.Equal(e.dst) {
|
||||
tunOutbound <- e
|
||||
util.SafeWrite(tunOutbound, e)
|
||||
continue
|
||||
}
|
||||
_, err := packetConn.WriteTo(e.data[:e.length], remoteAddr)
|
||||
@@ -104,7 +107,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut
|
||||
errChan <- errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr))
|
||||
return
|
||||
}
|
||||
tunOutbound <- &DataElem{data: b[:], length: n}
|
||||
util.SafeWrite(tunOutbound, &DataElem{data: b[:], length: n})
|
||||
}
|
||||
}()
|
||||
|
||||
|
@@ -41,7 +41,6 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
|
||||
var sshConf = util.ParseSshFromRPC(req.SshJump)
|
||||
var transferImage = req.TransferImage
|
||||
|
||||
go util.StartupPProf(config.PProfPort)
|
||||
defaultlog.Default().SetOutput(io.Discard)
|
||||
if transferImage {
|
||||
err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
|
||||
|
@@ -61,7 +61,6 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
|
||||
var sshConf = util.ParseSshFromRPC(req.SshJump)
|
||||
var transferImage = req.TransferImage
|
||||
|
||||
go util.StartupPProf(config.PProfPort)
|
||||
defaultlog.Default().SetOutput(io.Discard)
|
||||
if transferImage {
|
||||
err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
|
||||
|
@@ -48,7 +48,7 @@ func (w *wsHandler) handle(ctx context.Context) {
|
||||
ctx, f := context.WithCancel(ctx)
|
||||
defer f()
|
||||
|
||||
cli, err := util.DialSshRemote(w.sshConfig)
|
||||
cli, err := util.DialSshRemote(ctx, w.sshConfig)
|
||||
if err != nil {
|
||||
w.Log("Dial ssh remote error: %v", err)
|
||||
return
|
||||
|
@@ -69,6 +69,12 @@ func (c *ConnectOptions) Cleanup() {
|
||||
log.Errorf("can not update ref-count: %v", err)
|
||||
}
|
||||
}
|
||||
// leave proxy resources
|
||||
err := c.LeaveProxyResources(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("leave proxy resources error: %v", err)
|
||||
}
|
||||
|
||||
for _, function := range c.getRolloutFunc() {
|
||||
if function != nil {
|
||||
if err := function(); err != nil {
|
||||
@@ -76,11 +82,6 @@ func (c *ConnectOptions) Cleanup() {
|
||||
}
|
||||
}
|
||||
}
|
||||
// leave proxy resources
|
||||
err := c.LeaveProxyResources(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("leave proxy resources error: %v", err)
|
||||
}
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
@@ -620,6 +620,10 @@ func Run(ctx context.Context, servers []core.Server) error {
|
||||
errChan <- func() error {
|
||||
svr := servers[i]
|
||||
defer svr.Listener.Close()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
svr.Listener.Close()
|
||||
}()
|
||||
for ctx.Err() == nil {
|
||||
conn, err := svr.Listener.Accept()
|
||||
if err != nil {
|
||||
|
29
pkg/util/chan.go
Normal file
29
pkg/util/chan.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package util
|
||||
|
||||
func SafeRead[T any](c chan T) (T, bool) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
}
|
||||
}()
|
||||
tt, ok := <-c
|
||||
return tt, ok
|
||||
}
|
||||
|
||||
func SafeWrite[T any](c chan<- T, value T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case c <- value:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func SafeClose[T any](c chan T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
}
|
||||
}()
|
||||
close(c)
|
||||
}
|
23
pkg/util/chan_test.go
Normal file
23
pkg/util/chan_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestChanClose(t *testing.T) {
|
||||
c := make(chan any)
|
||||
close(c)
|
||||
SafeWrite(c, nil)
|
||||
|
||||
c = make(chan any)
|
||||
go func() {
|
||||
time.AfterFunc(time.Second*3, func() {
|
||||
close(c)
|
||||
})
|
||||
}()
|
||||
for a := range c {
|
||||
fmt.Printf("%v", a)
|
||||
}
|
||||
}
|
@@ -118,7 +118,7 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
|
||||
|
||||
// transfer image to remote
|
||||
var sshClient *ssh.Client
|
||||
sshClient, err = DialSshRemote(conf)
|
||||
sshClient, err = DialSshRemote(ctx, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -114,7 +114,7 @@ func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
|
||||
}
|
||||
|
||||
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||
func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) {
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if remote != nil {
|
||||
@@ -124,21 +124,24 @@ func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) {
|
||||
}()
|
||||
|
||||
if conf.ConfigAlias != "" {
|
||||
remote, err = conf.AliasRecursion()
|
||||
remote, err = conf.AliasRecursion(ctx)
|
||||
} else if conf.Jump != "" {
|
||||
remote, err = conf.JumpRecursion()
|
||||
remote, err = conf.JumpRecursion(ctx)
|
||||
} else {
|
||||
remote, err = conf.Dial()
|
||||
remote, err = conf.Dial(ctx)
|
||||
}
|
||||
|
||||
// ref: https://github.com/golang/go/issues/21478
|
||||
if err == nil {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second * 15)
|
||||
defer ticker.Stop()
|
||||
defer remote.Close()
|
||||
for range ticker.C {
|
||||
for ctx.Err() == nil {
|
||||
time.Sleep(time.Second * 15)
|
||||
_, _, err := remote.SendRequest("keepalive@golang.org", true, nil)
|
||||
if err == nil || err.Error() == "request failed" {
|
||||
// Any response is a success.
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -234,7 +237,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||
return ssh.PublicKeys(key), nil
|
||||
}
|
||||
|
||||
func copyStream(local net.Conn, remote net.Conn) {
|
||||
func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
|
||||
chDone := make(chan bool, 2)
|
||||
|
||||
// start remote -> local data transfer
|
||||
@@ -265,10 +268,15 @@ func copyStream(local net.Conn, remote net.Conn) {
|
||||
}
|
||||
}()
|
||||
|
||||
<-chDone
|
||||
select {
|
||||
case <-chDone:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
|
||||
func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client, err error) {
|
||||
var name = config.ConfigAlias
|
||||
var jumper = "ProxyJump"
|
||||
var bastionList = []SshConfig{GetBastion(name, config)}
|
||||
@@ -283,12 +291,12 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
|
||||
}
|
||||
for i := len(bastionList) - 1; i >= 0; i-- {
|
||||
if client == nil {
|
||||
client, err = bastionList[i].Dial()
|
||||
client, err = bastionList[i].Dial(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
client, err = JumpTo(client, bastionList[i])
|
||||
client, err = JumpTo(ctx, client, bastionList[i])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -297,7 +305,7 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
|
||||
func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client, err error) {
|
||||
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||
var sshConf = &SshConfig{}
|
||||
AddSshFlags(flags, sshConf)
|
||||
@@ -306,7 +314,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
|
||||
return nil, err
|
||||
}
|
||||
var baseClient *ssh.Client
|
||||
baseClient, err = DialSshRemote(sshConf)
|
||||
baseClient, err = DialSshRemote(ctx, sshConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -331,7 +339,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
|
||||
}
|
||||
|
||||
for _, sshConfig := range bastionList {
|
||||
client, err = JumpTo(baseClient, sshConfig)
|
||||
client, err = JumpTo(ctx, baseClient, sshConfig)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -374,7 +382,7 @@ func GetBastion(name string, defaultValue SshConfig) SshConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func (config SshConfig) Dial() (*ssh.Client, error) {
|
||||
func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error) {
|
||||
if strings.Index(config.Addr, ":") < 0 {
|
||||
// use default ssh port 22
|
||||
config.Addr = net.JoinHostPort(config.Addr, "22")
|
||||
@@ -384,16 +392,31 @@ func (config SshConfig) Dial() (*ssh.Client, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.Dial("tcp", config.Addr, &ssh.ClientConfig{
|
||||
conn, err := net.DialTimeout("tcp", config.Addr, time.Second*10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}()
|
||||
c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{
|
||||
User: config.User,
|
||||
Auth: authMethod,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
BannerCallback: ssh.BannerDisplayStderr(),
|
||||
Timeout: time.Second * 10,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.NewClient(c, chans, reqs), nil
|
||||
}
|
||||
|
||||
func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
||||
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
||||
if strings.Index(to.Addr, ":") < 0 {
|
||||
// use default ssh port 22
|
||||
to.Addr = net.JoinHostPort(to.Addr, "22")
|
||||
@@ -410,6 +433,14 @@ func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
bClient.Close()
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if client != nil {
|
||||
@@ -495,7 +526,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
}
|
||||
sshClient.Close()
|
||||
}
|
||||
sshClient, err = DialSshRemote(conf)
|
||||
sshClient, err = DialSshRemote(ctx, conf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to dial remote ssh server: %v", err)
|
||||
return nil, err
|
||||
@@ -505,11 +536,20 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
|
||||
go func() {
|
||||
defer localListen.Close()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
localListen.Close()
|
||||
if sshClient != nil {
|
||||
sshClient.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
for ctx.Err() == nil {
|
||||
localConn, err := localListen.Accept()
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept conn: %v", err)
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
log.Errorf("failed to accept conn: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
@@ -521,7 +561,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
copyStream(localConn, remoteConn)
|
||||
copyStream(ctx, localConn, remoteConn)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
@@ -551,7 +591,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
|
||||
|
||||
// pre-check network ip connect
|
||||
var cli *ssh.Client
|
||||
cli, err = DialSshRemote(conf)
|
||||
cli, err = DialSshRemote(ctx, conf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user