hotfix: fix ssh re-connect logic

This commit is contained in:
naison
2024-10-10 02:28:30 +00:00
parent baf5b79a24
commit 15103837a7
2 changed files with 13 additions and 36 deletions

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
@@ -39,13 +38,12 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) {
for ctx.Err() == nil {
packetConn, err := getRemotePacketConn(ctx, h.chain)
if err != nil {
log.Errorf("[TUN-CLIENT] Failed to get remote conn from %s -> %s: %s", tun.LocalAddr(), remoteAddr, err)
time.Sleep(time.Second * 2)
log.Debugf("[TUN-CLIENT] Failed to get remote conn from %s -> %s: %s", tun.LocalAddr(), remoteAddr, err)
continue
}
err = transportTunClient(ctx, tunInbound, tunOutbound, packetConn, remoteAddr)
if err != nil {
log.Errorf("[TUN-CLIENT] %s: %v", tun.LocalAddr(), err)
log.Debugf("[TUN-CLIENT] %s: %v", tun.LocalAddr(), err)
}
}
})

View File

@@ -247,7 +247,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
defer config.LPool.Put(buf[:])
_, err := io.CopyBuffer(local, remote, buf)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
log.Errorf("Failed to copy remote -> local: %s", err)
log.Debugf("Failed to copy remote -> local: %s", err)
}
select {
case chDone <- true:
@@ -261,7 +261,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
defer config.LPool.Put(buf[:])
_, err := io.CopyBuffer(remote, local, buf)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
log.Errorf("Failed to copy local -> remote: %s", err)
log.Debugf("Failed to copy local -> remote: %s", err)
}
select {
case chDone <- true:
@@ -513,39 +513,11 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
return err
}
var lock sync.Mutex
var sshClient *ssh.Client
var getRemoteConnFunc = func() (net.Conn, error) {
lock.Lock()
defer lock.Unlock()
if sshClient != nil {
ctx1, cancelFunc := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc()
remoteConn, err := sshClient.DialContext(ctx1, "tcp", remote.String())
if err == nil {
return remoteConn, nil
}
sshClient.Close()
sshClient = nil
}
sshClient, err = DialSshRemote(ctx, conf)
if err != nil {
log.Errorf("failed to dial remote ssh server: %v", err)
return nil, err
}
return sshClient.Dial("tcp", remote.String())
}
go func() {
defer localListen.Close()
go func() {
<-ctx.Done()
localListen.Close()
if sshClient != nil {
sshClient.Close()
}
}()
for ctx.Err() == nil {
@@ -559,9 +531,16 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
go func() {
defer localConn.Close()
remoteConn, err := getRemoteConnFunc()
sshClient, err := DialSshRemote(ctx, conf)
if err != nil {
log.Errorf("Failed to dial %s: %s", remote.String(), err)
marshal, _ := json.Marshal(conf)
log.Debugf("Failed to dial remote ssh server %v : %v", string(marshal), err)
return
}
defer sshClient.Close()
remoteConn, err := sshClient.DialContext(ctx, "tcp", remote.String())
if err != nil {
log.Debugf("Failed to dial %s: %s", remote.String(), err)
return
}
defer remoteConn.Close()