mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-12-24 11:51:13 +08:00
hotfix: fix ssh and port-forward retry bug (#360)
Co-authored-by: fengcaiwen <fengcaiwen@bytedance.com>
This commit is contained in:
@@ -42,18 +42,18 @@ func TCPForwarder(s *stack.Stack) func(stack.TransportEndpointID, *stack.PacketB
|
||||
|
||||
remote, err := forwardChain.dial(context.Background())
|
||||
if err != nil {
|
||||
log.Errorf("[TUN-TCP] Failed to dial remote conn: %v", err)
|
||||
log.Debugf("[TUN-TCP] Failed to dial remote conn: %v", err)
|
||||
return
|
||||
}
|
||||
if err = WriteProxyInfo(remote, id); err != nil {
|
||||
log.Errorf("[TUN-TCP] Failed to write proxy info: %v", err)
|
||||
log.Debugf("[TUN-TCP] Failed to write proxy info: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w := &waiter.Queue{}
|
||||
endpoint, tErr := request.CreateEndpoint(w)
|
||||
if tErr != nil {
|
||||
log.Errorf("[TUN-TCP] Failed to create endpoint: %v", tErr)
|
||||
log.Debugf("[TUN-TCP] Failed to create endpoint: %v", tErr)
|
||||
return
|
||||
}
|
||||
conn := gonet.NewTCPConn(w, endpoint)
|
||||
@@ -77,7 +77,7 @@ func TCPForwarder(s *stack.Stack) func(stack.TransportEndpointID, *stack.PacketB
|
||||
}()
|
||||
err = <-errChan
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Errorf("[TUN-TCP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
|
||||
log.Debugf("[TUN-TCP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
|
||||
}
|
||||
}).HandlePacket
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (h *gvisorTCPHandler) Handle(ctx context.Context, tcpConn net.Conn) {
|
||||
}()
|
||||
err = <-errChan
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Errorf("[TUN-TCP] Disconnect: %s >-<: %s: %v", tcpConn.LocalAddr(), remote.RemoteAddr(), err)
|
||||
log.Debugf("[TUN-TCP] Disconnect: %s >-<: %s: %v", tcpConn.LocalAddr(), remote.RemoteAddr(), err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -108,9 +108,9 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
|
||||
//defer pkt.DecRef()
|
||||
config.LPool.Put(bytes[:])
|
||||
endpoint.InjectInbound(protocol, pkt)
|
||||
log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||
log.Tracef("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||
} else {
|
||||
log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||
log.Tracef("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
|
||||
util.SafeWrite(in, NewDataElem(bytes[:], read, src, dst))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,13 +26,13 @@ func UDPForwarder(s *stack.Stack) func(id stack.TransportEndpointID, pkt *stack.
|
||||
w := &waiter.Queue{}
|
||||
endpoint, tErr := request.CreateEndpoint(w)
|
||||
if tErr != nil {
|
||||
log.Errorf("[TUN-UDP] Failed to create endpoint: %v", tErr)
|
||||
log.Debugf("[TUN-UDP] Failed to create endpoint: %v", tErr)
|
||||
return
|
||||
}
|
||||
|
||||
node, err := ParseNode(GvisorUDPForwardAddr)
|
||||
if err != nil {
|
||||
log.Errorf("[TUN-UDP] Failed to parse gviosr udp forward addr %s: %v", GvisorUDPForwardAddr, err)
|
||||
log.Debugf("[TUN-UDP] Failed to parse gviosr udp forward addr %s: %v", GvisorUDPForwardAddr, err)
|
||||
return
|
||||
}
|
||||
node.Client = &Client{
|
||||
@@ -44,16 +44,16 @@ func UDPForwarder(s *stack.Stack) func(id stack.TransportEndpointID, pkt *stack.
|
||||
ctx := context.Background()
|
||||
c, err := forwardChain.getConn(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("[TUN-UDP] Failed to get conn: %v", err)
|
||||
log.Debugf("[TUN-UDP] Failed to get conn: %v", err)
|
||||
return
|
||||
}
|
||||
if err = WriteProxyInfo(c, endpointID); err != nil {
|
||||
log.Errorf("[TUN-UDP] Failed to write proxy info: %v", err)
|
||||
log.Debugf("[TUN-UDP] Failed to write proxy info: %v", err)
|
||||
return
|
||||
}
|
||||
remote, err := node.Client.ConnectContext(ctx, c)
|
||||
if err != nil {
|
||||
log.Errorf("[TUN-UDP] Failed to connect: %v", err)
|
||||
log.Debugf("[TUN-UDP] Failed to connect: %v", err)
|
||||
return
|
||||
}
|
||||
conn := gonet.NewUDPConn(w, endpoint)
|
||||
@@ -77,7 +77,7 @@ func UDPForwarder(s *stack.Stack) func(id stack.TransportEndpointID, pkt *stack.
|
||||
}()
|
||||
err = <-errChan
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Errorf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
|
||||
log.Debugf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
|
||||
}
|
||||
}()
|
||||
}).HandlePacket
|
||||
|
||||
@@ -50,7 +50,7 @@ func (w *wsHandler) handle(c context.Context) {
|
||||
ctx, f := context.WithCancel(c)
|
||||
defer f()
|
||||
|
||||
cli, err := pkgssh.DialSshRemote(ctx, w.sshConfig)
|
||||
cli, err := pkgssh.DialSshRemote(ctx, w.sshConfig, ctx.Done())
|
||||
if err != nil {
|
||||
w.Log("Dial ssh remote error: %v", err)
|
||||
return
|
||||
|
||||
@@ -154,7 +154,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1
|
||||
return
|
||||
}
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Error(err)
|
||||
log.Debugf("Failed to watch service to add route table: %v", err)
|
||||
}
|
||||
if utilnet.IsConnectionRefused(err) || apierrors.IsTooManyRequests(err) || apierrors.IsForbidden(err) {
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/containernetworking/cni/pkg/types"
|
||||
"github.com/distribution/reference"
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
"github.com/libp2p/go-netroute"
|
||||
miekgdns "github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -31,6 +32,7 @@ import (
|
||||
pkgruntime "k8s.io/apimachinery/pkg/runtime"
|
||||
pkgtypes "k8s.io/apimachinery/pkg/types"
|
||||
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||
"k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
"k8s.io/cli-runtime/pkg/resource"
|
||||
@@ -246,36 +248,47 @@ func (c *ConnectOptions) DoConnect(ctx context.Context, isLite bool) (err error)
|
||||
|
||||
// detect pod is delete event, if pod is deleted, needs to redo port-forward immediately
|
||||
func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) error {
|
||||
var readyChan = make(chan struct{}, 1)
|
||||
firstCtx, firstCancelFunc := context.WithCancel(ctx)
|
||||
defer firstCancelFunc()
|
||||
var errChan = make(chan error, 1)
|
||||
podInterface := c.clientset.CoreV1().Pods(c.Namespace)
|
||||
var out = log.StandardLogger().WriterLevel(log.DebugLevel)
|
||||
go func() {
|
||||
defer out.Close()
|
||||
runtime.ErrorHandlers = []func(error){}
|
||||
var first = pointer.Bool(true)
|
||||
for c.ctx.Err() == nil {
|
||||
for ctx.Err() == nil {
|
||||
func() {
|
||||
defer time.Sleep(time.Millisecond * 200)
|
||||
|
||||
sortBy := func(pods []*v1.Pod) sort.Interface { return sort.Reverse(podutils.ActivePods(pods)) }
|
||||
label := fields.OneTermEqualSelector("app", config.ConfigMapPodTrafficManager).String()
|
||||
_, _, _ = polymorphichelpers.GetFirstPod(c.clientset.CoreV1(), c.Namespace, label, time.Second*5, sortBy)
|
||||
podList, err := c.GetRunningPodList(ctx)
|
||||
_, _, _ = polymorphichelpers.GetFirstPod(c.clientset.CoreV1(), c.Namespace, label, time.Second*10, sortBy)
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
podList, err := c.GetRunningPodList(ctx2)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get running pod: %v", err)
|
||||
log.Debugf("Failed to get running pod: %v", err)
|
||||
if *first {
|
||||
errChan <- err
|
||||
util.SafeWrite(errChan, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
childCtx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
if !*first {
|
||||
readyChan = nil
|
||||
}
|
||||
var readyChan = make(chan struct{})
|
||||
podName := podList[0].GetName()
|
||||
// try to detect pod is delete event, if pod is deleted, needs to redo port-forward
|
||||
go util.CheckPodStatus(childCtx, cancelFunc, podName, podInterface)
|
||||
//go util.CheckPodStatus(childCtx, cancelFunc, podName, c.clientset.CoreV1().Pods(c.Namespace))
|
||||
go util.CheckPortStatus(childCtx, cancelFunc, readyChan, strings.Split(portPair[1], ":")[0])
|
||||
if *first {
|
||||
go func() {
|
||||
select {
|
||||
case <-readyChan:
|
||||
firstCancelFunc()
|
||||
case <-childCtx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
var out = log.StandardLogger().WriterLevel(log.DebugLevel)
|
||||
defer out.Close()
|
||||
err = util.PortForwardPod(
|
||||
c.config,
|
||||
c.restclient,
|
||||
@@ -288,19 +301,19 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
|
||||
out,
|
||||
)
|
||||
if *first {
|
||||
errChan <- err
|
||||
util.SafeWrite(errChan, err)
|
||||
}
|
||||
first = pointer.Bool(false)
|
||||
// exit normal, let context.err to judge to exit or not
|
||||
if err == nil {
|
||||
log.Errorf("Port forward retrying")
|
||||
log.Debugf("Port forward retrying")
|
||||
return
|
||||
}
|
||||
if strings.Contains(err.Error(), "unable to listen on any of the requested ports") ||
|
||||
strings.Contains(err.Error(), "address already in use") {
|
||||
log.Errorf("Port %s already in use, needs to release it manually", portPair)
|
||||
log.Debugf("Port %s already in use, needs to release it manually", portPair)
|
||||
} else {
|
||||
log.Errorf("Port-forward occurs error: %v", err)
|
||||
log.Debugf("Port-forward occurs error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -312,7 +325,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
|
||||
return errors.New("wait port forward to be ready timeout")
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-readyChan:
|
||||
case <-firstCtx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -433,6 +446,12 @@ func (c *ConnectOptions) addRouteDynamic(ctx context.Context) error {
|
||||
} else {
|
||||
mask = net.CIDRMask(128, 128)
|
||||
}
|
||||
if r, err := netroute.New(); err == nil {
|
||||
iface, _, _, err := r.Route(ip)
|
||||
if err == nil && iface.Name == tunName {
|
||||
return
|
||||
}
|
||||
}
|
||||
errs := tun.AddRoutes(tunName, types.Route{Dst: net.IPNet{IP: ip, Mask: mask}})
|
||||
if errs != nil {
|
||||
log.Errorf("Failed to add route, resource: %s, IP: %s, err: %v", resource, ip, errs)
|
||||
|
||||
@@ -112,7 +112,7 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
|
||||
|
||||
// transfer image to remote
|
||||
var sshClient *ssh.Client
|
||||
sshClient, err = DialSshRemote(ctx, conf)
|
||||
sshClient, err = DialSshRemote(ctx, conf, ctx.Done())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
152
pkg/ssh/ssh.go
152
pkg/ssh/ssh.go
@@ -116,7 +116,7 @@ func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
|
||||
}
|
||||
|
||||
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, err error) {
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *ssh.Client, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if remote != nil {
|
||||
@@ -126,11 +126,11 @@ func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, er
|
||||
}()
|
||||
|
||||
if conf.ConfigAlias != "" {
|
||||
remote, err = conf.AliasRecursion(ctx)
|
||||
remote, err = conf.AliasRecursion(ctx, stopChan)
|
||||
} else if conf.Jump != "" {
|
||||
remote, err = conf.JumpRecursion(ctx)
|
||||
remote, err = conf.JumpRecursion(ctx, stopChan)
|
||||
} else {
|
||||
remote, err = conf.Dial(ctx)
|
||||
remote, err = conf.Dial(ctx, stopChan)
|
||||
}
|
||||
|
||||
// ref: https://github.com/golang/go/issues/21478
|
||||
@@ -287,7 +287,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client, err error) {
|
||||
func (config SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
var name = config.ConfigAlias
|
||||
var jumper = "ProxyJump"
|
||||
var bastionList = []SshConfig{GetBastion(name, config)}
|
||||
@@ -302,12 +302,12 @@ func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client,
|
||||
}
|
||||
for i := len(bastionList) - 1; i >= 0; i-- {
|
||||
if client == nil {
|
||||
client, err = bastionList[i].Dial(ctx)
|
||||
client, err = bastionList[i].Dial(ctx, stopChan)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
client, err = JumpTo(ctx, client, bastionList[i])
|
||||
client, err = JumpTo(ctx, client, bastionList[i], stopChan)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -316,7 +316,7 @@ func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client,
|
||||
return
|
||||
}
|
||||
|
||||
func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client, err error) {
|
||||
func (config SshConfig) JumpRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||
var sshConf = &SshConfig{}
|
||||
AddSshFlags(flags, sshConf)
|
||||
@@ -325,7 +325,7 @@ func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client,
|
||||
return nil, err
|
||||
}
|
||||
var baseClient *ssh.Client
|
||||
baseClient, err = DialSshRemote(ctx, sshConf)
|
||||
baseClient, err = DialSshRemote(ctx, sshConf, stopChan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client,
|
||||
}
|
||||
|
||||
for _, sshConfig := range bastionList {
|
||||
client, err = JumpTo(ctx, baseClient, sshConfig)
|
||||
client, err = JumpTo(ctx, baseClient, sshConfig, stopChan)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -393,7 +393,7 @@ func GetBastion(name string, defaultValue SshConfig) SshConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error) {
|
||||
func (config SshConfig) Dial(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
if _, _, err = net.SplitHostPort(config.Addr); err != nil {
|
||||
// use default ssh port 22
|
||||
config.Addr = net.JoinHostPort(config.Addr, "22")
|
||||
@@ -404,15 +404,28 @@ func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.DialTimeout("tcp", config.Addr, time.Second*10)
|
||||
d := net.Dialer{Timeout: time.Second * 10}
|
||||
conn, err := d.DialContext(ctx, "tcp", config.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
if stopChan != nil {
|
||||
<-stopChan
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{
|
||||
@@ -428,7 +441,7 @@ func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error
|
||||
return ssh.NewClient(c, chans, reqs), nil
|
||||
}
|
||||
|
||||
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
|
||||
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
if _, _, err = net.SplitHostPort(to.Addr); err != nil {
|
||||
// use default ssh port 22
|
||||
to.Addr = net.JoinHostPort(to.Addr, "22")
|
||||
@@ -442,17 +455,19 @@ func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh
|
||||
}
|
||||
// Dial a connection to the service host, from the bastion
|
||||
var conn net.Conn
|
||||
conn, err = bClient.Dial("tcp", to.Addr)
|
||||
conn, err = bClient.DialContext(ctx, "tcp", to.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
if stopChan != nil {
|
||||
<-stopChan
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
bClient.Close()
|
||||
}
|
||||
bClient.Close()
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -517,6 +532,16 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
type sshClient struct {
|
||||
cancel context.CancelFunc
|
||||
*ssh.Client
|
||||
}
|
||||
|
||||
func (c *sshClient) Close() error {
|
||||
c.cancel()
|
||||
return c.Client.Close()
|
||||
}
|
||||
|
||||
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
|
||||
// Listen on remote server port
|
||||
var lc net.ListenConfig
|
||||
@@ -529,6 +554,66 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
go func() {
|
||||
defer localListen.Close()
|
||||
|
||||
var sshClientChan = make(chan *sshClient, 1000*1000)
|
||||
|
||||
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) {
|
||||
select {
|
||||
case cli, ok := <-sshClientChan:
|
||||
if !ok {
|
||||
return nil, errors.New("ssh client chan closed")
|
||||
}
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc1()
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
cli.Close()
|
||||
return nil, err
|
||||
}
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-connCtx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
return conn, nil
|
||||
default:
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *ssh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
marshal, _ := json.Marshal(conf)
|
||||
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
|
||||
return nil, err
|
||||
}
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
cli := &sshClient{cancel: cancelFunc1, Client: client}
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-connCtx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
localConn, err1 := localListen.Accept()
|
||||
if err1 != nil {
|
||||
@@ -540,18 +625,19 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
cCtx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
sshClient, err := DialSshRemote(cCtx, conf)
|
||||
var remoteConn net.Conn
|
||||
var err error
|
||||
for i := 0; i < 5; i++ {
|
||||
remoteConn, err = getRemoteConnFunc(cCtx)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
marshal, _ := json.Marshal(conf)
|
||||
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
|
||||
return
|
||||
}
|
||||
defer sshClient.Close()
|
||||
remoteConn, err := sshClient.DialContext(cCtx, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial %s: %s", remote.String(), err)
|
||||
log.Debugf("Failed to get remote conn: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer remoteConn.Close()
|
||||
copyStream(cCtx, localConn, remoteConn)
|
||||
}()
|
||||
@@ -584,7 +670,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
|
||||
|
||||
// pre-check network ip connect
|
||||
var cli *ssh.Client
|
||||
cli, err = DialSshRemote(ctx, conf)
|
||||
cli, err = DialSshRemote(ctx, conf, ctx.Done())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,14 +9,16 @@ func SafeRead[T any](c chan T) (T, bool) {
|
||||
return tt, ok
|
||||
}
|
||||
|
||||
func SafeWrite[T any](c chan<- T, value T) {
|
||||
func SafeWrite[T any](c chan<- T, value T) bool {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case c <- value:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
"k8s.io/apimachinery/pkg/watch"
|
||||
"k8s.io/cli-runtime/pkg/genericiooptions"
|
||||
"k8s.io/cli-runtime/pkg/resource"
|
||||
@@ -32,6 +34,7 @@ import (
|
||||
"k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/tools/portforward"
|
||||
"k8s.io/client-go/transport/spdy"
|
||||
"k8s.io/client-go/util/retry"
|
||||
"k8s.io/kubectl/pkg/cmd/exec"
|
||||
"k8s.io/kubectl/pkg/cmd/util"
|
||||
"k8s.io/kubectl/pkg/polymorphichelpers"
|
||||
@@ -169,7 +172,7 @@ func PortForwardPod(config *rest.Config, clientset *rest.RESTClient, podName, na
|
||||
}
|
||||
|
||||
if err = forwarder.ForwardPorts(); err != nil {
|
||||
log.Errorf("Forward port error: %s", err.Error())
|
||||
log.Debugf("Forward port error: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -336,8 +339,7 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
|
||||
})
|
||||
if err != nil {
|
||||
if !k8serrors.IsForbidden(err) && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Failed to watch Pod %s: %v", podName, err)
|
||||
cancelFunc()
|
||||
log.Debugf("Failed to watch Pod %s: %v", podName, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -346,8 +348,7 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
|
||||
_, err = podInterface.Get(ctx, podName, v1.GetOptions{})
|
||||
if err != nil {
|
||||
if !k8serrors.IsForbidden(err) && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Failed to get Pod %s: %v", podName, err)
|
||||
cancelFunc()
|
||||
log.Debugf("Failed to get Pod %s: %v", podName, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -356,20 +357,20 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
|
||||
if !ok {
|
||||
_, err = podInterface.Get(ctx, podName, v1.GetOptions{})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Failed to get Pod %s: %v", podName, err)
|
||||
log.Debugf("Failed to get Pod %s: %v", podName, err)
|
||||
cancelFunc()
|
||||
}
|
||||
return
|
||||
}
|
||||
switch e.Type {
|
||||
case watch.Deleted:
|
||||
log.Errorf("Pod %s is deleted", podName)
|
||||
log.Debugf("Pod %s is deleted", podName)
|
||||
cancelFunc()
|
||||
return
|
||||
case watch.Error:
|
||||
_, err = podInterface.Get(ctx, podName, v1.GetOptions{})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Failed to get Pod %s: %v", podName, err)
|
||||
log.Debugf("Failed to get Pod %s: %v", podName, err)
|
||||
cancelFunc()
|
||||
}
|
||||
return
|
||||
@@ -381,6 +382,43 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
|
||||
}
|
||||
}
|
||||
|
||||
func CheckPortStatus(ctx context.Context, cancelFunc context.CancelFunc, readyChan chan struct{}, localGvisorTCPPort string) {
|
||||
defer cancelFunc()
|
||||
ticker := time.NewTicker(time.Second * 60)
|
||||
defer ticker.Stop()
|
||||
|
||||
select {
|
||||
case <-readyChan:
|
||||
case <-ticker.C:
|
||||
log.Debugf("Wait port-forward to be ready timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
err := retry.OnError(wait.Backoff{
|
||||
Steps: 6,
|
||||
Duration: time.Second,
|
||||
}, func(err error) bool {
|
||||
return err != nil
|
||||
}, func() error {
|
||||
var lc net.ListenConfig
|
||||
conn, err := lc.Listen(ctx, "tcp", net.JoinHostPort("127.0.0.1", localGvisorTCPPort))
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
return errors.New("port is free")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Debugf("Can not dial local port: %s: %v", localGvisorTCPPort, err)
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}
|
||||
|
||||
func Rollback(f util.Factory, ns, workload string) {
|
||||
r := f.NewBuilder().
|
||||
WithScheme(scheme2.Scheme, scheme2.Scheme.PrioritizedVersionsAllGroups()...).
|
||||
|
||||
Reference in New Issue
Block a user