hotfix: fix ssh and port-forward retry bug (#360)

Co-authored-by: fengcaiwen <fengcaiwen@bytedance.com>
This commit is contained in:
naison
2024-10-25 21:25:03 +08:00
committed by GitHub
parent 07292fcde5
commit aa881a589e
11 changed files with 221 additions and 76 deletions

View File

@@ -42,18 +42,18 @@ func TCPForwarder(s *stack.Stack) func(stack.TransportEndpointID, *stack.PacketB
remote, err := forwardChain.dial(context.Background())
if err != nil {
log.Errorf("[TUN-TCP] Failed to dial remote conn: %v", err)
log.Debugf("[TUN-TCP] Failed to dial remote conn: %v", err)
return
}
if err = WriteProxyInfo(remote, id); err != nil {
log.Errorf("[TUN-TCP] Failed to write proxy info: %v", err)
log.Debugf("[TUN-TCP] Failed to write proxy info: %v", err)
return
}
w := &waiter.Queue{}
endpoint, tErr := request.CreateEndpoint(w)
if tErr != nil {
log.Errorf("[TUN-TCP] Failed to create endpoint: %v", tErr)
log.Debugf("[TUN-TCP] Failed to create endpoint: %v", tErr)
return
}
conn := gonet.NewTCPConn(w, endpoint)
@@ -77,7 +77,7 @@ func TCPForwarder(s *stack.Stack) func(stack.TransportEndpointID, *stack.PacketB
}()
err = <-errChan
if err != nil && !errors.Is(err, io.EOF) {
log.Errorf("[TUN-TCP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
log.Debugf("[TUN-TCP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
}
}).HandlePacket
}

View File

@@ -84,7 +84,7 @@ func (h *gvisorTCPHandler) Handle(ctx context.Context, tcpConn net.Conn) {
}()
err = <-errChan
if err != nil && !errors.Is(err, io.EOF) {
log.Errorf("[TUN-TCP] Disconnect: %s >-<: %s: %v", tcpConn.LocalAddr(), remote.RemoteAddr(), err)
log.Debugf("[TUN-TCP] Disconnect: %s >-<: %s: %v", tcpConn.LocalAddr(), remote.RemoteAddr(), err)
}
}

View File

@@ -108,9 +108,9 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config
//defer pkt.DecRef()
config.LPool.Put(bytes[:])
endpoint.InjectInbound(protocol, pkt)
log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
log.Tracef("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
} else {
log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
log.Tracef("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read)
util.SafeWrite(in, NewDataElem(bytes[:], read, src, dst))
}
}

View File

@@ -26,13 +26,13 @@ func UDPForwarder(s *stack.Stack) func(id stack.TransportEndpointID, pkt *stack.
w := &waiter.Queue{}
endpoint, tErr := request.CreateEndpoint(w)
if tErr != nil {
log.Errorf("[TUN-UDP] Failed to create endpoint: %v", tErr)
log.Debugf("[TUN-UDP] Failed to create endpoint: %v", tErr)
return
}
node, err := ParseNode(GvisorUDPForwardAddr)
if err != nil {
log.Errorf("[TUN-UDP] Failed to parse gviosr udp forward addr %s: %v", GvisorUDPForwardAddr, err)
log.Debugf("[TUN-UDP] Failed to parse gviosr udp forward addr %s: %v", GvisorUDPForwardAddr, err)
return
}
node.Client = &Client{
@@ -44,16 +44,16 @@ func UDPForwarder(s *stack.Stack) func(id stack.TransportEndpointID, pkt *stack.
ctx := context.Background()
c, err := forwardChain.getConn(ctx)
if err != nil {
log.Errorf("[TUN-UDP] Failed to get conn: %v", err)
log.Debugf("[TUN-UDP] Failed to get conn: %v", err)
return
}
if err = WriteProxyInfo(c, endpointID); err != nil {
log.Errorf("[TUN-UDP] Failed to write proxy info: %v", err)
log.Debugf("[TUN-UDP] Failed to write proxy info: %v", err)
return
}
remote, err := node.Client.ConnectContext(ctx, c)
if err != nil {
log.Errorf("[TUN-UDP] Failed to connect: %v", err)
log.Debugf("[TUN-UDP] Failed to connect: %v", err)
return
}
conn := gonet.NewUDPConn(w, endpoint)
@@ -77,7 +77,7 @@ func UDPForwarder(s *stack.Stack) func(id stack.TransportEndpointID, pkt *stack.
}()
err = <-errChan
if err != nil && !errors.Is(err, io.EOF) {
log.Errorf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
log.Debugf("[TUN-UDP] Disconnect: %s >-<: %s: %v", conn.LocalAddr(), remote.RemoteAddr(), err)
}
}()
}).HandlePacket

View File

@@ -50,7 +50,7 @@ func (w *wsHandler) handle(c context.Context) {
ctx, f := context.WithCancel(c)
defer f()
cli, err := pkgssh.DialSshRemote(ctx, w.sshConfig)
cli, err := pkgssh.DialSshRemote(ctx, w.sshConfig, ctx.Done())
if err != nil {
w.Log("Dial ssh remote error: %v", err)
return

View File

@@ -154,7 +154,7 @@ func (c *Config) watchServiceToAddHosts(ctx context.Context, serviceInterface v1
return
}
if err != nil && !errors.Is(err, context.Canceled) {
log.Error(err)
log.Debugf("Failed to watch service to add route table: %v", err)
}
if utilnet.IsConnectionRefused(err) || apierrors.IsTooManyRequests(err) || apierrors.IsForbidden(err) {
time.Sleep(time.Second * 1)

View File

@@ -20,6 +20,7 @@ import (
"github.com/containernetworking/cni/pkg/types"
"github.com/distribution/reference"
goversion "github.com/hashicorp/go-version"
"github.com/libp2p/go-netroute"
miekgdns "github.com/miekg/dns"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
@@ -31,6 +32,7 @@ import (
pkgruntime "k8s.io/apimachinery/pkg/runtime"
pkgtypes "k8s.io/apimachinery/pkg/types"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/cli-runtime/pkg/resource"
@@ -246,36 +248,47 @@ func (c *ConnectOptions) DoConnect(ctx context.Context, isLite bool) (err error)
// Detect the pod-deleted event; if the pod is deleted, the port-forward must be redone immediately.
func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) error {
var readyChan = make(chan struct{}, 1)
firstCtx, firstCancelFunc := context.WithCancel(ctx)
defer firstCancelFunc()
var errChan = make(chan error, 1)
podInterface := c.clientset.CoreV1().Pods(c.Namespace)
var out = log.StandardLogger().WriterLevel(log.DebugLevel)
go func() {
defer out.Close()
runtime.ErrorHandlers = []func(error){}
var first = pointer.Bool(true)
for c.ctx.Err() == nil {
for ctx.Err() == nil {
func() {
defer time.Sleep(time.Millisecond * 200)
sortBy := func(pods []*v1.Pod) sort.Interface { return sort.Reverse(podutils.ActivePods(pods)) }
label := fields.OneTermEqualSelector("app", config.ConfigMapPodTrafficManager).String()
_, _, _ = polymorphichelpers.GetFirstPod(c.clientset.CoreV1(), c.Namespace, label, time.Second*5, sortBy)
podList, err := c.GetRunningPodList(ctx)
_, _, _ = polymorphichelpers.GetFirstPod(c.clientset.CoreV1(), c.Namespace, label, time.Second*10, sortBy)
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2()
podList, err := c.GetRunningPodList(ctx2)
if err != nil {
log.Errorf("Failed to get running pod: %v", err)
log.Debugf("Failed to get running pod: %v", err)
if *first {
errChan <- err
util.SafeWrite(errChan, err)
}
return
}
childCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
if !*first {
readyChan = nil
}
var readyChan = make(chan struct{})
podName := podList[0].GetName()
// try to detect pod is delete event, if pod is deleted, needs to redo port-forward
go util.CheckPodStatus(childCtx, cancelFunc, podName, podInterface)
//go util.CheckPodStatus(childCtx, cancelFunc, podName, c.clientset.CoreV1().Pods(c.Namespace))
go util.CheckPortStatus(childCtx, cancelFunc, readyChan, strings.Split(portPair[1], ":")[0])
if *first {
go func() {
select {
case <-readyChan:
firstCancelFunc()
case <-childCtx.Done():
}
}()
}
var out = log.StandardLogger().WriterLevel(log.DebugLevel)
defer out.Close()
err = util.PortForwardPod(
c.config,
c.restclient,
@@ -288,19 +301,19 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
out,
)
if *first {
errChan <- err
util.SafeWrite(errChan, err)
}
first = pointer.Bool(false)
// exit normal, let context.err to judge to exit or not
if err == nil {
log.Errorf("Port forward retrying")
log.Debugf("Port forward retrying")
return
}
if strings.Contains(err.Error(), "unable to listen on any of the requested ports") ||
strings.Contains(err.Error(), "address already in use") {
log.Errorf("Port %s already in use, needs to release it manually", portPair)
log.Debugf("Port %s already in use, needs to release it manually", portPair)
} else {
log.Errorf("Port-forward occurs error: %v", err)
log.Debugf("Port-forward occurs error: %v", err)
}
}()
}
@@ -312,7 +325,7 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err
return errors.New("wait port forward to be ready timeout")
case err := <-errChan:
return err
case <-readyChan:
case <-firstCtx.Done():
return nil
}
}
@@ -433,6 +446,12 @@ func (c *ConnectOptions) addRouteDynamic(ctx context.Context) error {
} else {
mask = net.CIDRMask(128, 128)
}
if r, err := netroute.New(); err == nil {
iface, _, _, err := r.Route(ip)
if err == nil && iface.Name == tunName {
return
}
}
errs := tun.AddRoutes(tunName, types.Route{Dst: net.IPNet{IP: ip, Mask: mask}})
if errs != nil {
log.Errorf("Failed to add route, resource: %s, IP: %s, err: %v", resource, ip, errs)

View File

@@ -112,7 +112,7 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
// transfer image to remote
var sshClient *ssh.Client
sshClient, err = DialSshRemote(ctx, conf)
sshClient, err = DialSshRemote(ctx, conf, ctx.Done())
if err != nil {
return err
}

View File

@@ -116,7 +116,7 @@ func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
}
// DialSshRemote https://github.com/golang/go/issues/21478
func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, err error) {
func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *ssh.Client, err error) {
defer func() {
if err != nil {
if remote != nil {
@@ -126,11 +126,11 @@ func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, er
}()
if conf.ConfigAlias != "" {
remote, err = conf.AliasRecursion(ctx)
remote, err = conf.AliasRecursion(ctx, stopChan)
} else if conf.Jump != "" {
remote, err = conf.JumpRecursion(ctx)
remote, err = conf.JumpRecursion(ctx, stopChan)
} else {
remote, err = conf.Dial(ctx)
remote, err = conf.Dial(ctx, stopChan)
}
// ref: https://github.com/golang/go/issues/21478
@@ -287,7 +287,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
}
}
func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client, err error) {
func (config SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
var name = config.ConfigAlias
var jumper = "ProxyJump"
var bastionList = []SshConfig{GetBastion(name, config)}
@@ -302,12 +302,12 @@ func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client,
}
for i := len(bastionList) - 1; i >= 0; i-- {
if client == nil {
client, err = bastionList[i].Dial(ctx)
client, err = bastionList[i].Dial(ctx, stopChan)
if err != nil {
return
}
} else {
client, err = JumpTo(ctx, client, bastionList[i])
client, err = JumpTo(ctx, client, bastionList[i], stopChan)
if err != nil {
return
}
@@ -316,7 +316,7 @@ func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client,
return
}
func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client, err error) {
func (config SshConfig) JumpRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
var sshConf = &SshConfig{}
AddSshFlags(flags, sshConf)
@@ -325,7 +325,7 @@ func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client,
return nil, err
}
var baseClient *ssh.Client
baseClient, err = DialSshRemote(ctx, sshConf)
baseClient, err = DialSshRemote(ctx, sshConf, stopChan)
if err != nil {
return nil, err
}
@@ -350,7 +350,7 @@ func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client,
}
for _, sshConfig := range bastionList {
client, err = JumpTo(ctx, baseClient, sshConfig)
client, err = JumpTo(ctx, baseClient, sshConfig, stopChan)
if err != nil {
return
}
@@ -393,7 +393,7 @@ func GetBastion(name string, defaultValue SshConfig) SshConfig {
return config
}
func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error) {
func (config SshConfig) Dial(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
if _, _, err = net.SplitHostPort(config.Addr); err != nil {
// use default ssh port 22
config.Addr = net.JoinHostPort(config.Addr, "22")
@@ -404,15 +404,28 @@ func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error
if err != nil {
return nil, err
}
conn, err := net.DialTimeout("tcp", config.Addr, time.Second*10)
d := net.Dialer{Timeout: time.Second * 10}
conn, err := d.DialContext(ctx, "tcp", config.Addr)
if err != nil {
return nil, err
}
go func() {
<-ctx.Done()
conn.Close()
if client != nil {
client.Close()
if stopChan != nil {
<-stopChan
conn.Close()
if client != nil {
client.Close()
}
}
}()
defer func() {
if err != nil {
if conn != nil {
conn.Close()
}
if client != nil {
client.Close()
}
}
}()
c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{
@@ -428,7 +441,7 @@ func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error
return ssh.NewClient(c, chans, reqs), nil
}
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) {
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig, stopChan <-chan struct{}) (client *ssh.Client, err error) {
if _, _, err = net.SplitHostPort(to.Addr); err != nil {
// use default ssh port 22
to.Addr = net.JoinHostPort(to.Addr, "22")
@@ -442,17 +455,19 @@ func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh
}
// Dial a connection to the service host, from the bastion
var conn net.Conn
conn, err = bClient.Dial("tcp", to.Addr)
conn, err = bClient.DialContext(ctx, "tcp", to.Addr)
if err != nil {
return
}
go func() {
<-ctx.Done()
conn.Close()
if client != nil {
client.Close()
if stopChan != nil {
<-stopChan
conn.Close()
if client != nil {
client.Close()
}
bClient.Close()
}
bClient.Close()
}()
defer func() {
if err != nil {
@@ -517,6 +532,16 @@ func init() {
})
}
// sshClient bundles an *ssh.Client with a cancel function that has to be
// invoked when the client is discarded.
type sshClient struct {
	// cancel is called by Close before the embedded client is closed.
	cancel context.CancelFunc
	*ssh.Client
}

// Close invokes the stored cancel function first and then closes the embedded
// *ssh.Client, returning whatever error that close reports.
func (c *sshClient) Close() error {
	c.cancel()
	closeErr := c.Client.Close()
	return closeErr
}
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
// Listen on remote server port
var lc net.ListenConfig
@@ -529,6 +554,66 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
go func() {
defer localListen.Close()
var sshClientChan = make(chan *sshClient, 1000*1000)
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) {
select {
case cli, ok := <-sshClientChan:
if !ok {
return nil, errors.New("ssh client chan closed")
}
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc1()
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
if err != nil {
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
cli.Close()
return nil, err
}
write := pkgutil.SafeWrite(sshClientChan, cli)
if !write {
go func() {
<-connCtx.Done()
cli.Close()
}()
}
return conn, nil
default:
ctx1, cancelFunc1 := context.WithCancel(ctx)
defer func() {
if err != nil {
cancelFunc1()
}
}()
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2()
var client *ssh.Client
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
if err != nil {
marshal, _ := json.Marshal(conf)
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
return nil, err
}
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc3()
conn, err = client.DialContext(ctx3, "tcp", remote.String())
if err != nil {
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
client.Close()
return nil, err
}
cli := &sshClient{cancel: cancelFunc1, Client: client}
write := pkgutil.SafeWrite(sshClientChan, cli)
if !write {
go func() {
<-connCtx.Done()
cli.Close()
}()
}
return conn, nil
}
}
for ctx.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
@@ -540,18 +625,19 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
cCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
sshClient, err := DialSshRemote(cCtx, conf)
var remoteConn net.Conn
var err error
for i := 0; i < 5; i++ {
remoteConn, err = getRemoteConnFunc(cCtx)
if err == nil {
break
}
}
if err != nil {
marshal, _ := json.Marshal(conf)
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
return
}
defer sshClient.Close()
remoteConn, err := sshClient.DialContext(cCtx, "tcp", remote.String())
if err != nil {
log.Debugf("Failed to dial %s: %s", remote.String(), err)
log.Debugf("Failed to get remote conn: %v", err)
return
}
defer remoteConn.Close()
copyStream(cCtx, localConn, remoteConn)
}()
@@ -584,7 +670,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
// pre-check network ip connect
var cli *ssh.Client
cli, err = DialSshRemote(ctx, conf)
cli, err = DialSshRemote(ctx, conf, ctx.Done())
if err != nil {
return
}

View File

@@ -9,14 +9,16 @@ func SafeRead[T any](c chan T) (T, bool) {
return tt, ok
}
func SafeWrite[T any](c chan<- T, value T) {
// SafeWrite attempts a non-blocking send of value on c and reports whether
// the value was handed off.
//
// It returns false when the send cannot proceed immediately (for example the
// buffer is full or nobody is receiving), and also when the send panics (for
// example c is already closed): the panic is recovered and the zero result,
// false, is returned.
func SafeWrite[T any](c chan<- T, value T) bool {
	// Swallow any panic raised by the send so callers never crash here.
	defer func() { _ = recover() }()

	select {
	case c <- value:
		return true
	default:
	}
	return false
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
@@ -24,6 +25,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/cli-runtime/pkg/genericiooptions"
"k8s.io/cli-runtime/pkg/resource"
@@ -32,6 +34,7 @@ import (
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
"k8s.io/client-go/util/retry"
"k8s.io/kubectl/pkg/cmd/exec"
"k8s.io/kubectl/pkg/cmd/util"
"k8s.io/kubectl/pkg/polymorphichelpers"
@@ -169,7 +172,7 @@ func PortForwardPod(config *rest.Config, clientset *rest.RESTClient, podName, na
}
if err = forwarder.ForwardPorts(); err != nil {
log.Errorf("Forward port error: %s", err.Error())
log.Debugf("Forward port error: %s", err.Error())
return err
}
return nil
@@ -336,8 +339,7 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
})
if err != nil {
if !k8serrors.IsForbidden(err) && !errors.Is(err, context.Canceled) {
log.Errorf("Failed to watch Pod %s: %v", podName, err)
cancelFunc()
log.Debugf("Failed to watch Pod %s: %v", podName, err)
}
return
}
@@ -346,8 +348,7 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
_, err = podInterface.Get(ctx, podName, v1.GetOptions{})
if err != nil {
if !k8serrors.IsForbidden(err) && !errors.Is(err, context.Canceled) {
log.Errorf("Failed to get Pod %s: %v", podName, err)
cancelFunc()
log.Debugf("Failed to get Pod %s: %v", podName, err)
}
return
}
@@ -356,20 +357,20 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
if !ok {
_, err = podInterface.Get(ctx, podName, v1.GetOptions{})
if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("Failed to get Pod %s: %v", podName, err)
log.Debugf("Failed to get Pod %s: %v", podName, err)
cancelFunc()
}
return
}
switch e.Type {
case watch.Deleted:
log.Errorf("Pod %s is deleted", podName)
log.Debugf("Pod %s is deleted", podName)
cancelFunc()
return
case watch.Error:
_, err = podInterface.Get(ctx, podName, v1.GetOptions{})
if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("Failed to get Pod %s: %v", podName, err)
log.Debugf("Failed to get Pod %s: %v", podName, err)
cancelFunc()
}
return
@@ -381,6 +382,43 @@ func CheckPodStatus(ctx context.Context, cancelFunc context.CancelFunc, podName
}
}
// CheckPortStatus watches the local side of a port-forward and calls
// cancelFunc when it returns, for whatever reason.
//
// Phase 1: wait for readyChan to deliver (or be closed). If that does not
// happen within 60 seconds, or ctx ends first, the function returns.
//
// Phase 2: while ctx is alive, probe 127.0.0.1:localGvisorTCPPort by trying to
// listen on it. A failed listen is treated as "the port is still held" and the
// probe succeeds; a successful listen means the port is free, which is
// reported as an error and retried according to the wait.Backoff below. If the
// port is still free once the retries are exhausted, the function returns.
// Between successful probes it waits 5 seconds, returning early if ctx is
// cancelled during the wait.
func CheckPortStatus(ctx context.Context, cancelFunc context.CancelFunc, readyChan chan struct{}, localGvisorTCPPort string) {
	defer cancelFunc()

	// One-shot readiness deadline: a Timer fires once, unlike a Ticker.
	readyTimeout := time.NewTimer(time.Second * 60)
	defer readyTimeout.Stop()

	select {
	case <-readyChan:
	case <-readyTimeout.C:
		log.Debugf("Wait port-forward to be ready timeout")
		return
	case <-ctx.Done():
		return
	}

	for ctx.Err() == nil {
		err := retry.OnError(wait.Backoff{
			Steps:    6,
			Duration: time.Second,
		}, func(err error) bool {
			return err != nil
		}, func() error {
			var lc net.ListenConfig
			conn, err := lc.Listen(ctx, "tcp", net.JoinHostPort("127.0.0.1", localGvisorTCPPort))
			if err == nil {
				// We were able to bind the port ourselves, so nothing else holds it.
				_ = conn.Close()
				return errors.New("port is free")
			}
			return nil
		})
		if err != nil {
			log.Debugf("Can not dial local port: %s: %v", localGvisorTCPPort, err)
			return
		}

		// Pause between probes, but stop waiting as soon as ctx is cancelled
		// so this goroutine (and the deferred cancelFunc) is not held up.
		pause := time.NewTimer(time.Second * 5)
		select {
		case <-ctx.Done():
			pause.Stop()
			return
		case <-pause.C:
		}
	}
}
func Rollback(f util.Factory, ns, workload string) {
r := f.NewBuilder().
WithScheme(scheme2.Scheme, scheme2.Scheme.PrioritizedVersionsAllGroups()...).