mirror of
https://github.com/containers/gvisor-tap-vsock.git
synced 2025-10-26 02:00:36 +08:00
Implement windows ssh proxy with windows pipe support
Signed-off-by: Jason T. Greene <jason.greene@redhat.com>
This commit is contained in:
246
pkg/sshclient/ssh_forwarder.go
Normal file
246
pkg/sshclient/ssh_forwarder.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package sshclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containers/gvisor-tap-vsock/pkg/fs"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type CloseWriteStream interface {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type SSHForward struct {
|
||||
listener net.Listener
|
||||
bastion *Bastion
|
||||
sock *url.URL
|
||||
}
|
||||
|
||||
type SSHDialer interface {
|
||||
DialContextTCP(ctx context.Context, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type genericTCPDialer struct {
|
||||
}
|
||||
|
||||
var defaultTCPDialer genericTCPDialer
|
||||
|
||||
func (dialer *genericTCPDialer) DialContextTCP(ctx context.Context, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
|
||||
func CreateSSHForward(ctx context.Context, src *url.URL, dest *url.URL, identity string, dialer SSHDialer) (*SSHForward, error) {
|
||||
if dialer == nil {
|
||||
dialer = &defaultTCPDialer
|
||||
}
|
||||
|
||||
return setupProxy(ctx, src, dest, identity, dialer)
|
||||
}
|
||||
|
||||
func (forward *SSHForward) AcceptAndTunnel(ctx context.Context) error {
|
||||
return acceptConnection(ctx, forward.listener, forward.bastion, forward.sock)
|
||||
}
|
||||
|
||||
func (forward *SSHForward) Close() {
|
||||
if forward.listener != nil {
|
||||
forward.listener.Close()
|
||||
}
|
||||
if forward.bastion != nil {
|
||||
forward.bastion.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func connectForward(ctx context.Context, bastion *Bastion) (CloseWriteStream, error) {
|
||||
for retries := 1; ; retries++ {
|
||||
forward, err := bastion.Client.Dial("unix", bastion.Path)
|
||||
if err == nil {
|
||||
return forward.(ssh.Channel), nil
|
||||
}
|
||||
if retries > 2 {
|
||||
return nil, errors.Wrapf(err, "Couldn't restablish ssh tunnel on path: %s", bastion.Path)
|
||||
}
|
||||
// Check if ssh connection is still alive
|
||||
_, _, err = bastion.Client.Conn.SendRequest("alive@gvproxy", true, nil)
|
||||
if err != nil {
|
||||
for bastionRetries := 1; ; bastionRetries++ {
|
||||
err = bastion.Reconnect()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if bastionRetries > 2 || !sleep(ctx, 200*time.Millisecond) {
|
||||
return nil, errors.Wrapf(err, "Couldn't reestablish ssh connection: %s", bastion.Host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !sleep(ctx, 200*time.Millisecond) {
|
||||
retries = 3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func listenUnix(socketURI *url.URL) (net.Listener, error) {
|
||||
path := socketURI.Path
|
||||
if runtime.GOOS == "windows" {
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
}
|
||||
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldmask := fs.Umask(0177)
|
||||
defer fs.Umask(oldmask)
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return listener, errors.Wrapf(err, "Error listening on socket: %s", socketURI.Path)
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity string, dialer SSHDialer) (*SSHForward, error) {
|
||||
var (
|
||||
listener net.Listener
|
||||
err error
|
||||
)
|
||||
switch socketURI.Scheme {
|
||||
case "unix":
|
||||
listener, err = listenUnix(socketURI)
|
||||
if err != nil {
|
||||
return &SSHForward{}, err
|
||||
}
|
||||
case "npipe":
|
||||
listener, err = listenNpipe(socketURI)
|
||||
if err != nil {
|
||||
return &SSHForward{}, err
|
||||
}
|
||||
default:
|
||||
return &SSHForward{}, errors.Errorf("URI scheme not supported: %s", socketURI.Scheme)
|
||||
}
|
||||
|
||||
connectFunc := func(bastion *Bastion) (net.Conn, error) {
|
||||
timeout := 5 * time.Second
|
||||
if bastion != nil {
|
||||
timeout = bastion.Config.Timeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
conn, err := dialer.DialContextTCP(ctx, dest.Host)
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
conn, err := initialConnection(ctx, connectFunc)
|
||||
if err != nil {
|
||||
return &SSHForward{}, err
|
||||
}
|
||||
|
||||
bastion, err := CreateBastion(dest, "", identity, conn, connectFunc)
|
||||
if err != nil {
|
||||
return &SSHForward{}, err
|
||||
}
|
||||
|
||||
logrus.Infof("Socket forward established: %s to %s\n", socketURI.Path, dest.Path)
|
||||
|
||||
return &SSHForward{listener, &bastion, socketURI}, nil
|
||||
}
|
||||
|
||||
func initialConnection(ctx context.Context, connectFunc ConnectCallback) (net.Conn, error) {
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
)
|
||||
|
||||
backoff := 100 * time.Millisecond
|
||||
|
||||
loop:
|
||||
for i := 0; i < 60; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
default:
|
||||
// proceed
|
||||
}
|
||||
|
||||
conn, err = connectFunc(nil)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
logrus.Debugf("Waiting for sshd: %s", backoff)
|
||||
sleep(ctx, backoff)
|
||||
backoff = backOff(backoff)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error {
|
||||
con, err := listener.Accept()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "Error accepting on socket: %s", socketURI.Path)
|
||||
}
|
||||
|
||||
src, ok := con.(CloseWriteStream)
|
||||
if !ok {
|
||||
con.Close()
|
||||
return errors.Wrapf(err, "Underlying socket does not support half-close %s", socketURI.Path)
|
||||
}
|
||||
|
||||
var dest CloseWriteStream
|
||||
|
||||
dest, err = connectForward(ctx, bastion)
|
||||
if err != nil {
|
||||
con.Close()
|
||||
logrus.Error(err)
|
||||
return nil // eat
|
||||
}
|
||||
|
||||
go forward(src, dest)
|
||||
go forward(dest, src)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func forward(src io.ReadCloser, dest CloseWriteStream) {
|
||||
defer src.Close()
|
||||
_, _ = io.Copy(dest, src)
|
||||
|
||||
// Trigger an EOF on the other end
|
||||
_ = dest.CloseWrite()
|
||||
}
|
||||
|
||||
func backOff(delay time.Duration) time.Duration {
|
||||
if delay == 0 {
|
||||
delay = 5 * time.Millisecond
|
||||
} else {
|
||||
delay *= 2
|
||||
}
|
||||
if delay > time.Second {
|
||||
delay = time.Second
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func sleep(ctx context.Context, wait time.Duration) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(wait):
|
||||
return true
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user