mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-21 22:39:36 +08:00
hotfix: close chan (#245)
This commit is contained in:
@@ -3,13 +3,17 @@ package cmds
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
cmdutil "k8s.io/kubectl/pkg/cmd/util"
|
cmdutil "k8s.io/kubectl/pkg/cmd/util"
|
||||||
"k8s.io/kubectl/pkg/util/i18n"
|
"k8s.io/kubectl/pkg/util/i18n"
|
||||||
"k8s.io/kubectl/pkg/util/templates"
|
"k8s.io/kubectl/pkg/util/templates"
|
||||||
|
|
||||||
|
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon"
|
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon"
|
||||||
|
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
|
func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
|
||||||
@@ -24,10 +28,21 @@ func CmdDaemon(_ cmdutil.Factory) *cobra.Command {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
opt.ID = base64.URLEncoding.EncodeToString(b)
|
opt.ID = base64.URLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
if opt.IsSudo {
|
||||||
|
go util.StartupPProf(config.SudoPProfPort)
|
||||||
|
} else {
|
||||||
|
go util.StartupPProf(config.PProfPort)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) (err error) {
|
||||||
defer opt.Stop()
|
defer opt.Stop()
|
||||||
|
defer func() {
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
return opt.Start(cmd.Context())
|
return opt.Start(cmd.Context())
|
||||||
},
|
},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
|
@@ -83,7 +83,8 @@ const (
|
|||||||
ManageBy = konfig.ManagedbyLabelKey
|
ManageBy = konfig.ManagedbyLabelKey
|
||||||
|
|
||||||
// pprof port
|
// pprof port
|
||||||
PProfPort = 32345
|
PProfPort = 32345
|
||||||
|
SudoPProfPort = 33345
|
||||||
|
|
||||||
// startup by KubeVPN
|
// startup by KubeVPN
|
||||||
EnvStartSudoKubeVPNByKubeVPN = "DEPTH_SIGNED_BY_NAISON"
|
EnvStartSudoKubeVPNByKubeVPN = "DEPTH_SIGNED_BY_NAISON"
|
||||||
|
@@ -104,7 +104,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("[tcpserver] new routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
|
log.Debugf("[tcpserver] new routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
|
||||||
}
|
}
|
||||||
h.ch <- dgram
|
util.SafeWrite(h.ch, dgram)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -18,6 +18,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
|
||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||||
|
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config.Engine, in chan<- *DataElem, out chan *DataElem) stack.LinkEndpoint {
|
func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config.Engine, in chan<- *DataElem, out chan *DataElem) stack.LinkEndpoint {
|
||||||
@@ -37,7 +38,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
|||||||
i := config.LPool.Get().([]byte)[:]
|
i := config.LPool.Get().([]byte)[:]
|
||||||
n := copy(i, bb)
|
n := copy(i, bb)
|
||||||
bb = nil
|
bb = nil
|
||||||
out <- NewDataElem(i[:], n, nil, nil)
|
util.SafeWrite(out, NewDataElem(i[:], n, nil, nil))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -49,7 +50,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
|||||||
read, err := tun.Read(bytes[:])
|
read, err := tun.Read(bytes[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) {
|
if errors.Is(err, os.ErrClosed) {
|
||||||
log.Errorf("[TUN] Error: tun device closed")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// if context is done
|
// if context is done
|
||||||
@@ -111,7 +111,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
|||||||
log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||||
in <- NewDataElem(bytes[:], read, src, dst)
|
util.SafeWrite(in, NewDataElem(bytes[:], read, src, dst))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -121,7 +121,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
|||||||
config.LPool.Put(elem.Data()[:])
|
config.LPool.Put(elem.Data()[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) {
|
if errors.Is(err, os.ErrClosed) {
|
||||||
log.Errorf("[TUN] Error: tun device closed")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// if context is done
|
// if context is done
|
||||||
|
@@ -193,10 +193,10 @@ func (d *Device) readFromTun() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if n != 0 {
|
if n != 0 {
|
||||||
d.tunInboundRaw <- &DataElem{
|
util.SafeWrite(d.tunInboundRaw, &DataElem{
|
||||||
data: b[:],
|
data: b[:],
|
||||||
length: n,
|
length: n,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -238,12 +238,16 @@ func (d *Device) parseIPHeader(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("[tun] %s --> %s, length: %d", e.src, e.dst, e.length)
|
log.Debugf("[tun] %s --> %s, length: %d", e.src, e.dst, e.length)
|
||||||
d.tunInbound <- e
|
util.SafeWrite(d.tunInbound, e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Device) Close() {
|
func (d *Device) Close() {
|
||||||
d.tun.Close()
|
d.tun.Close()
|
||||||
|
util.SafeClose(d.tunInbound)
|
||||||
|
util.SafeClose(d.tunOutbound)
|
||||||
|
util.SafeClose(d.tunInboundRaw)
|
||||||
|
util.SafeClose(Chan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
|
func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
|
||||||
@@ -300,12 +304,12 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
|
|||||||
} else {
|
} else {
|
||||||
src, dst = srcIPv6, config.RouterIP6
|
src, dst = srcIPv6, config.RouterIP6
|
||||||
}
|
}
|
||||||
in <- &DataElem{
|
util.SafeWrite(in, &DataElem{
|
||||||
data: data[:],
|
data: data[:],
|
||||||
length: length,
|
length: length,
|
||||||
src: src,
|
src: src,
|
||||||
dst: dst,
|
dst: dst,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||||
|
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
|
func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
|
||||||
@@ -24,7 +25,9 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
|
|||||||
engine := h.node.Get(config.ConfigKubeVPNTransportEngine)
|
engine := h.node.Get(config.ConfigKubeVPNTransportEngine)
|
||||||
endpoint := NewTunEndpoint(ctx, tun, uint32(config.DefaultMTU), config.Engine(engine), in, out)
|
endpoint := NewTunEndpoint(ctx, tun, uint32(config.DefaultMTU), config.Engine(engine), in, out)
|
||||||
stack := NewStack(ctx, endpoint)
|
stack := NewStack(ctx, endpoint)
|
||||||
go stack.Wait()
|
defer stack.Destroy()
|
||||||
|
defer util.SafeClose(in)
|
||||||
|
defer util.SafeClose(out)
|
||||||
|
|
||||||
d := &ClientDevice{
|
d := &ClientDevice{
|
||||||
tun: tun,
|
tun: tun,
|
||||||
@@ -84,7 +87,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut
|
|||||||
go func() {
|
go func() {
|
||||||
for e := range tunInbound {
|
for e := range tunInbound {
|
||||||
if e.src.Equal(e.dst) {
|
if e.src.Equal(e.dst) {
|
||||||
tunOutbound <- e
|
util.SafeWrite(tunOutbound, e)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
_, err := packetConn.WriteTo(e.data[:e.length], remoteAddr)
|
_, err := packetConn.WriteTo(e.data[:e.length], remoteAddr)
|
||||||
@@ -104,7 +107,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut
|
|||||||
errChan <- errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr))
|
errChan <- errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tunOutbound <- &DataElem{data: b[:], length: n}
|
util.SafeWrite(tunOutbound, &DataElem{data: b[:], length: n})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@@ -41,7 +41,6 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
|
|||||||
var sshConf = util.ParseSshFromRPC(req.SshJump)
|
var sshConf = util.ParseSshFromRPC(req.SshJump)
|
||||||
var transferImage = req.TransferImage
|
var transferImage = req.TransferImage
|
||||||
|
|
||||||
go util.StartupPProf(config.PProfPort)
|
|
||||||
defaultlog.Default().SetOutput(io.Discard)
|
defaultlog.Default().SetOutput(io.Discard)
|
||||||
if transferImage {
|
if transferImage {
|
||||||
err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
|
err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
|
||||||
|
@@ -61,7 +61,6 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
|
|||||||
var sshConf = util.ParseSshFromRPC(req.SshJump)
|
var sshConf = util.ParseSshFromRPC(req.SshJump)
|
||||||
var transferImage = req.TransferImage
|
var transferImage = req.TransferImage
|
||||||
|
|
||||||
go util.StartupPProf(config.PProfPort)
|
|
||||||
defaultlog.Default().SetOutput(io.Discard)
|
defaultlog.Default().SetOutput(io.Discard)
|
||||||
if transferImage {
|
if transferImage {
|
||||||
err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
|
err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out)
|
||||||
|
@@ -48,7 +48,7 @@ func (w *wsHandler) handle(ctx context.Context) {
|
|||||||
ctx, f := context.WithCancel(ctx)
|
ctx, f := context.WithCancel(ctx)
|
||||||
defer f()
|
defer f()
|
||||||
|
|
||||||
cli, err := util.DialSshRemote(w.sshConfig)
|
cli, err := util.DialSshRemote(ctx, w.sshConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Log("Dial ssh remote error: %v", err)
|
w.Log("Dial ssh remote error: %v", err)
|
||||||
return
|
return
|
||||||
|
@@ -69,6 +69,12 @@ func (c *ConnectOptions) Cleanup() {
|
|||||||
log.Errorf("can not update ref-count: %v", err)
|
log.Errorf("can not update ref-count: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// leave proxy resources
|
||||||
|
err := c.LeaveProxyResources(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("leave proxy resources error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, function := range c.getRolloutFunc() {
|
for _, function := range c.getRolloutFunc() {
|
||||||
if function != nil {
|
if function != nil {
|
||||||
if err := function(); err != nil {
|
if err := function(); err != nil {
|
||||||
@@ -76,11 +82,6 @@ func (c *ConnectOptions) Cleanup() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// leave proxy resources
|
|
||||||
err := c.LeaveProxyResources(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("leave proxy resources error: %v", err)
|
|
||||||
}
|
|
||||||
if c.cancel != nil {
|
if c.cancel != nil {
|
||||||
c.cancel()
|
c.cancel()
|
||||||
}
|
}
|
||||||
|
@@ -620,6 +620,10 @@ func Run(ctx context.Context, servers []core.Server) error {
|
|||||||
errChan <- func() error {
|
errChan <- func() error {
|
||||||
svr := servers[i]
|
svr := servers[i]
|
||||||
defer svr.Listener.Close()
|
defer svr.Listener.Close()
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
svr.Listener.Close()
|
||||||
|
}()
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
conn, err := svr.Listener.Accept()
|
conn, err := svr.Listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
29
pkg/util/chan.go
Normal file
29
pkg/util/chan.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
func SafeRead[T any](c chan T) (T, bool) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
tt, ok := <-c
|
||||||
|
return tt, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func SafeWrite[T any](c chan<- T, value T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case c <- value:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SafeClose[T any](c chan T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
close(c)
|
||||||
|
}
|
23
pkg/util/chan_test.go
Normal file
23
pkg/util/chan_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChanClose(t *testing.T) {
|
||||||
|
c := make(chan any)
|
||||||
|
close(c)
|
||||||
|
SafeWrite(c, nil)
|
||||||
|
|
||||||
|
c = make(chan any)
|
||||||
|
go func() {
|
||||||
|
time.AfterFunc(time.Second*3, func() {
|
||||||
|
close(c)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
for a := range c {
|
||||||
|
fmt.Printf("%v", a)
|
||||||
|
}
|
||||||
|
}
|
@@ -118,7 +118,7 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
|
|||||||
|
|
||||||
// transfer image to remote
|
// transfer image to remote
|
||||||
var sshClient *ssh.Client
|
var sshClient *ssh.Client
|
||||||
sshClient, err = DialSshRemote(conf)
|
sshClient, err = DialSshRemote(ctx, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -114,7 +114,7 @@ func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DialSshRemote https://github.com/golang/go/issues/21478
|
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||||
func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) {
|
func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if remote != nil {
|
if remote != nil {
|
||||||
@@ -124,21 +124,24 @@ func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if conf.ConfigAlias != "" {
|
if conf.ConfigAlias != "" {
|
||||||
remote, err = conf.AliasRecursion()
|
remote, err = conf.AliasRecursion(ctx)
|
||||||
} else if conf.Jump != "" {
|
} else if conf.Jump != "" {
|
||||||
remote, err = conf.JumpRecursion()
|
remote, err = conf.JumpRecursion(ctx)
|
||||||
} else {
|
} else {
|
||||||
remote, err = conf.Dial()
|
remote, err = conf.Dial(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ref: https://github.com/golang/go/issues/21478
|
// ref: https://github.com/golang/go/issues/21478
|
||||||
if err == nil {
|
if err == nil {
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(time.Second * 15)
|
|
||||||
defer ticker.Stop()
|
|
||||||
defer remote.Close()
|
defer remote.Close()
|
||||||
for range ticker.C {
|
for ctx.Err() == nil {
|
||||||
|
time.Sleep(time.Second * 15)
|
||||||
_, _, err := remote.SendRequest("keepalive@golang.org", true, nil)
|
_, _, err := remote.SendRequest("keepalive@golang.org", true, nil)
|
||||||
|
if err == nil || err.Error() == "request failed" {
|
||||||
|
// Any response is a success.
|
||||||
|
continue
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -234,7 +237,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
|||||||
return ssh.PublicKeys(key), nil
|
return ssh.PublicKeys(key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyStream(local net.Conn, remote net.Conn) {
|
func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
|
||||||
chDone := make(chan bool, 2)
|
chDone := make(chan bool, 2)
|
||||||
|
|
||||||
// start remote -> local data transfer
|
// start remote -> local data transfer
|
||||||
@@ -265,10 +268,15 @@ func copyStream(local net.Conn, remote net.Conn) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
<-chDone
|
select {
|
||||||
|
case <-chDone:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
|
func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client, err error) {
|
||||||
var name = config.ConfigAlias
|
var name = config.ConfigAlias
|
||||||
var jumper = "ProxyJump"
|
var jumper = "ProxyJump"
|
||||||
var bastionList = []SshConfig{GetBastion(name, config)}
|
var bastionList = []SshConfig{GetBastion(name, config)}
|
||||||
@@ -283,12 +291,12 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
|
|||||||
}
|
}
|
||||||
for i := len(bastionList) - 1; i >= 0; i-- {
|
for i := len(bastionList) - 1; i >= 0; i-- {
|
||||||
if client == nil {
|
if client == nil {
|
||||||
client, err = bastionList[i].Dial()
|
client, err = bastionList[i].Dial(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
client, err = JumpTo(client, bastionList[i])
|
client, err = JumpTo(ctx, client, bastionList[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -297,7 +305,7 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
|
func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client, err error) {
|
||||||
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||||
var sshConf = &SshConfig{}
|
var sshConf = &SshConfig{}
|
||||||
AddSshFlags(flags, sshConf)
|
AddSshFlags(flags, sshConf)
|
||||||
@@ -306,7 +314,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var baseClient *ssh.Client
|
var baseClient *ssh.Client
|
||||||
baseClient, err = DialSshRemote(sshConf)
|
baseClient, err = DialSshRemote(ctx, sshConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -331,7 +339,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, sshConfig := range bastionList {
|
for _, sshConfig := range bastionList {
|
||||||
client, err = JumpTo(baseClient, sshConfig)
|
client, err = JumpTo(ctx, baseClient, sshConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -374,7 +382,7 @@ func GetBastion(name string, defaultValue SshConfig) SshConfig {
|
|||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (config SshConfig) Dial() (*ssh.Client, error) {
|
func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error) {
|
||||||
if strings.Index(config.Addr, ":") < 0 {
|
if strings.Index(config.Addr, ":") < 0 {
|
||||||
// use default ssh port 22
|
// use default ssh port 22
|
||||||
config.Addr = net.JoinHostPort(config.Addr, "22")
|
config.Addr = net.JoinHostPort(config.Addr, "22")
|
||||||
@@ -384,16 +392,31 @@ func (config SshConfig) Dial() (*ssh.Client, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return ssh.Dial("tcp", config.Addr, &ssh.ClientConfig{
|
conn, err := net.DialTimeout("tcp", config.Addr, time.Second*10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
conn.Close()
|
||||||
|
if client != nil {
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{
|
||||||
User: config.User,
|
User: config.User,
|
||||||
Auth: authMethod,
|
Auth: authMethod,
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
BannerCallback: ssh.BannerDisplayStderr(),
|
BannerCallback: ssh.BannerDisplayStderr(),
|
||||||
Timeout: time.Second * 10,
|
Timeout: time.Second * 10,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.NewClient(c, chans, reqs), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
||||||
if strings.Index(to.Addr, ":") < 0 {
|
if strings.Index(to.Addr, ":") < 0 {
|
||||||
// use default ssh port 22
|
// use default ssh port 22
|
||||||
to.Addr = net.JoinHostPort(to.Addr, "22")
|
to.Addr = net.JoinHostPort(to.Addr, "22")
|
||||||
@@ -410,6 +433,14 @@ func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
conn.Close()
|
||||||
|
if client != nil {
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
bClient.Close()
|
||||||
|
}()
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if client != nil {
|
if client != nil {
|
||||||
@@ -495,7 +526,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
|||||||
}
|
}
|
||||||
sshClient.Close()
|
sshClient.Close()
|
||||||
}
|
}
|
||||||
sshClient, err = DialSshRemote(conf)
|
sshClient, err = DialSshRemote(ctx, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to dial remote ssh server: %v", err)
|
log.Errorf("failed to dial remote ssh server: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -505,11 +536,20 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer localListen.Close()
|
defer localListen.Close()
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
localListen.Close()
|
||||||
|
if sshClient != nil {
|
||||||
|
sshClient.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
localConn, err := localListen.Accept()
|
localConn, err := localListen.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to accept conn: %v", err)
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
log.Errorf("failed to accept conn: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
@@ -521,7 +561,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer remoteConn.Close()
|
defer remoteConn.Close()
|
||||||
copyStream(localConn, remoteConn)
|
copyStream(ctx, localConn, remoteConn)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -551,7 +591,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
|
|||||||
|
|
||||||
// pre-check network ip connect
|
// pre-check network ip connect
|
||||||
var cli *ssh.Client
|
var cli *ssh.Client
|
||||||
cli, err = DialSshRemote(conf)
|
cli, err = DialSshRemote(ctx, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user