fix: fix some bug and optimize ssh logic (#147)

This commit is contained in:
naison
2024-02-07 20:36:10 +08:00
committed by GitHub
parent 59abb16136
commit 4abc5f004a
9 changed files with 186 additions and 216 deletions

View File

@@ -67,31 +67,7 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
}
}
func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, conf *SshConfig, done chan struct{}) error {
ctx, cancelFunc := context.WithCancel(pctx)
defer cancelFunc()
sshClient, err := DialSshRemote(conf)
if err != nil {
log.Errorf("Dial into remote server error: %s", err)
return err
}
// ref: https://github.com/golang/go/issues/21478
go func() {
ticker := time.NewTicker(time.Second * 5)
defer ticker.Stop()
select {
case <-ticker.C:
_, _, err := sshClient.SendRequest("keepalive@golang.org", true, nil)
if err != nil {
log.Errorf("failed to send keep alive error: %s", err)
}
case <-ctx.Done():
return
}
}()
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
// Listen on remote server port
var lc net.ListenConfig
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
@@ -105,34 +81,27 @@ func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, co
default:
}
// handle incoming connections on reverse forwarded tunnel
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
local, err := listen.Accept()
for ctx.Err() == nil {
localConn, err := listen.Accept()
if err != nil {
return err
}
go func(local net.Conn) {
defer local.Close()
conn, err := sshClient.Dial("tcp", remoteEndpoint.String())
go func(localConn net.Conn) {
defer localConn.Close()
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
if err != nil {
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
cancelFunc()
return
}
defer conn.Close()
handleClient(local, conn)
}(local)
defer remoteConn.Close()
copyStream(localConn, remoteConn)
}(localConn)
}
return ctx.Err()
}
// todo ssh heartbeats
// https://github.com/golang/go/issues/21478
func DialSshRemote(conf *SshConfig) (*ssh.Client, error) {
// DialSshRemote https://github.com/golang/go/issues/21478
func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
var remote *ssh.Client
var err error
if conf.ConfigAlias != "" {
@@ -189,22 +158,33 @@ func DialSshRemote(conf *SshConfig) (*ssh.Client, error) {
// Connect to SSH remote server using serverEndpoint
remote, err = ssh.Dial("tcp", conf.Addr, sshConfig)
}
// ref: https://github.com/golang/go/issues/21478
if err == nil {
go func() {
ticker := time.NewTicker(time.Second * 5)
defer ticker.Stop()
for ctx.Err() == nil {
select {
case <-ticker.C:
_, _, er := remote.SendRequest("keepalive@golang.org", true, nil)
if er != nil {
log.Errorf("failed to send keep alive error: %s", er)
}
}
}
}()
}
return remote, err
}
func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
var remote *ssh.Client
remote, err = DialSshRemote(conf)
if err != nil {
log.Errorf("Dial into remote server error: %s", err)
return
}
defer remote.Close()
func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
var session *ssh.Session
session, err = remote.NewSession()
session, err = client.NewSession()
if err != nil {
return
}
defer session.Close()
for k, v := range env {
// /etc/ssh/sshd_config
// AcceptEnv DEBIAN_FRONTEND
@@ -213,7 +193,6 @@ func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byt
err = nil
}
}
defer remote.Close()
var out bytes.Buffer
var er bytes.Buffer
session.Stdout = &out
@@ -246,7 +225,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
return ssh.PublicKeys(key), nil
}
func handleClient(client net.Conn, remote net.Conn) {
func copyStream(client net.Conn, remote net.Conn) {
chDone := make(chan bool, 2)
// start remote -> local data transfer
@@ -408,3 +387,32 @@ func init() {
}
})
}
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
errChan := make(chan error, 1)
readyChan := make(chan struct{}, 1)
go func() {
for ctx.Err() == nil {
err := PortMap(ctx, cli, remote, local, readyChan)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Errorf("Ssh forward failed err: %v", err)
}
select {
case errChan <- err:
default:
}
}
time.Sleep(time.Second * 2)
}
}()
select {
case <-readyChan:
return nil
case err := <-errChan:
log.Errorf("Ssh forward failed err: %v", err)
return err
case <-ctx.Done():
return ctx.Err()
}
}