ssh: Introduce 'retry' helper

initialConnection retries multiple times to establish the TCP connection
which will be used for ssh communication.
This commit adds a generic helper to handle the retry which will be
useful in the next commits.

Signed-off-by: Christophe Fergeau <cfergeau@redhat.com>
This commit is contained in:
Christophe Fergeau
2024-01-11 11:51:09 -05:00
parent f01fd1c0dd
commit aa3fa9a2bb

View File

@@ -2,6 +2,7 @@ package sshclient
import ( import (
"context" "context"
"fmt"
"io" "io"
"net" "net"
"net/url" "net/url"
@@ -180,16 +181,19 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity
return &SSHForward{listener, &bastion, socketURI}, nil return &SSHForward{listener, &bastion, socketURI}, nil
} }
func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) { const maxRetries = 60
const initialBackoff = 100 * time.Millisecond
func retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) {
var ( var (
conn net.Conn returnVal T
err error err error
) )
backoff := 100 * time.Millisecond backoff := initialBackoff
loop: loop:
for i := 0; i < 60; i++ { for i := 0; i < maxRetries; i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
break loop break loop
@@ -197,15 +201,22 @@ loop:
// proceed // proceed
} }
conn, err = connectFunc(ctx, nil) returnVal, err = retryFunc()
if err == nil { if err == nil {
break return returnVal, nil
} }
logrus.Debugf("Waiting for sshd: %s", backoff) logrus.Debugf("%s (%s)", retryMsg, backoff)
sleep(ctx, backoff) sleep(ctx, backoff)
backoff = backOff(backoff) backoff = backOff(backoff)
} }
return conn, err return returnVal, fmt.Errorf("timeout: %w", err)
}
func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) {
retryFunc := func() (net.Conn, error) {
return connectFunc(ctx, nil)
}
return retry(ctx, retryFunc, "Waiting for sshd socket")
} }
func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error { func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error {