Files
oneterm/backend/internal/tunneling/tunnel.go

277 lines
6.9 KiB
Go

package tunneling
import (
"context"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/spf13/cast"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/logger"
)
// GatewayTunnel represents a SSH tunnel through a gateway
type GatewayTunnel struct {
listener net.Listener
GatewayId int
SessionId string
LocalIp string
LocalPort int
RemoteIp string
RemotePort int
LocalConn net.Conn
RemoteConn net.Conn
Opened chan error
}
// Open opens the gateway tunnel
func (gt *GatewayTunnel) Open(sshClient *ssh.Client, isConnectable bool) (err error) {
go func() {
<-time.After(time.Second * 3)
logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId))
gt.listener.Close()
}()
defer func() {
logger.L().Debug("close listener", zap.String("sessionId", gt.SessionId), zap.Error(err))
gt.Opened <- err
}()
gt.Opened <- nil
gt.LocalConn, err = gt.listener.Accept()
if err != nil {
logger.L().Error("accept failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
return
}
remoteAddr := fmt.Sprintf("%s:%d", gt.RemoteIp, gt.RemotePort)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
gt.RemoteConn, err = sshClient.DialContext(ctx, "tcp", remoteAddr)
if err != nil {
defer func() {
if gt.LocalConn != nil {
defer gt.LocalConn.Close()
}
if gt.RemoteConn != nil {
defer gt.RemoteConn.Close()
}
}()
logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
return
}
if isConnectable {
return
}
go io.Copy(gt.LocalConn, gt.RemoteConn)
go io.Copy(gt.RemoteConn, gt.LocalConn)
return
}
// TunnelManager manages SSH tunnels through gateways
type TunnelManager struct {
gatewayTunnels map[string]*GatewayTunnel
sshClients map[int]*ssh.Client
sshClientsCount map[int]int
mtx sync.Mutex
}
// NewTunnelManager creates a new tunnel manager
func NewTunnelManager() *TunnelManager {
return &TunnelManager{
gatewayTunnels: map[string]*GatewayTunnel{},
sshClients: map[int]*ssh.Client{},
sshClientsCount: map[int]int{},
mtx: sync.Mutex{},
}
}
// GetTunnelBySessionId gets a gateway tunnel by session ID
func (tm *TunnelManager) GetTunnelBySessionId(sessionId string) *GatewayTunnel {
return tm.gatewayTunnels[sessionId]
}
// OpenTunnel opens a new gateway tunnel
func (tm *TunnelManager) OpenTunnel(isConnectable bool, sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (*GatewayTunnel, error) {
if gateway == nil {
return nil, fmt.Errorf("gateway is nil")
}
tm.mtx.Lock()
defer tm.mtx.Unlock()
sshCli, ok := tm.sshClients[gateway.Id]
if !ok {
auth, err := tm.getAuthMethod(gateway)
if err != nil {
return nil, err
}
sshCli, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", gateway.Host, gateway.Port), &ssh.ClientConfig{
User: gateway.Account,
Auth: []ssh.AuthMethod{auth},
Timeout: time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
logger.L().Error("open gateway sshcli failed", zap.Int("gatewayId", gateway.Id), zap.Error(err))
return nil, err
}
go func() {
logger.L().Debug("ssh proxy wait closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait()))
delete(tm.sshClients, gateway.Id)
}()
}
tm.sshClients[gateway.Id] = sshCli
tm.sshClientsCount[gateway.Id] += 1
localPort, err := getAvailablePort()
if err != nil {
return nil, err
}
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", localPort))
if err != nil {
return nil, err
}
g := &GatewayTunnel{
listener: listener,
GatewayId: gateway.Id,
SessionId: sessionId,
LocalIp: "localhost",
LocalPort: localPort,
RemoteIp: remoteIp,
RemotePort: remotePort,
Opened: make(chan error),
}
tm.gatewayTunnels[sessionId] = g
go g.Open(sshCli, isConnectable)
logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId))
<-g.Opened
logger.L().Debug("opened gateway", zap.Any("sessionId", sessionId))
return g, nil
}
// CloseTunnels closes gateway tunnels by session IDs
func (tm *TunnelManager) CloseTunnels(sessionIds ...string) {
tm.mtx.Lock()
defer tm.mtx.Unlock()
for _, sid := range sessionIds {
gt, ok := tm.gatewayTunnels[sid]
if !ok {
continue
}
tm.sshClientsCount[gt.GatewayId] -= 1
if tm.sshClientsCount[gt.GatewayId] <= 0 {
if g := tm.sshClients[gt.GatewayId]; g != nil {
g.Close()
}
delete(tm.sshClients, gt.GatewayId)
delete(tm.sshClientsCount, gt.GatewayId)
}
// Close and delete tunnel
if gt.listener != nil {
gt.listener.Close()
}
if gt.LocalConn != nil {
gt.LocalConn.Close()
}
if gt.RemoteConn != nil {
gt.RemoteConn.Close()
}
delete(tm.gatewayTunnels, sid)
}
}
// getAuthMethod gets SSH authentication method based on gateway config
func (tm *TunnelManager) getAuthMethod(gateway *model.Gateway) (ssh.AuthMethod, error) {
switch gateway.AccountType {
case model.AUTHMETHOD_PASSWORD:
return ssh.Password(gateway.Password), nil
case model.AUTHMETHOD_PUBLICKEY:
if gateway.Phrase == "" {
pk, err := ssh.ParsePrivateKey([]byte(gateway.Pk))
if err != nil {
return nil, err
}
return ssh.PublicKeys(pk), nil
} else {
pk, err := ssh.ParsePrivateKeyWithPassphrase([]byte(gateway.Pk), []byte(gateway.Phrase))
if err != nil {
return nil, err
}
return ssh.PublicKeys(pk), nil
}
default:
return nil, fmt.Errorf("invalid authmethod %d", gateway.AccountType)
}
}
// getAvailablePort gets an available local port
func getAvailablePort() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return 0, err
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return 0, err
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
// Proxy establishes a proxy connection to an asset through a gateway if necessary
func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
// Handle case 1: asset.Ip already contains port (e.g., "127.0.0.1:8000")
if strings.Contains(asset.Ip, ":") {
ipParts := strings.Split(asset.Ip, ":")
if len(ipParts) >= 2 {
ip = ipParts[0]
port = cast.ToInt(ipParts[1])
} else {
ip = asset.Ip
port = 0
}
} else {
// Case 2: asset.Ip without port (e.g., "127.0.0.1"), extract port from protocol
ip, port = asset.Ip, 0
for _, tp := range strings.Split(protocol, ",") {
for _, p := range asset.Protocols {
if strings.HasPrefix(strings.ToLower(p), tp) {
parts := strings.Split(p, ":")
if len(parts) >= 2 {
if port = cast.ToInt(parts[1]); port != 0 {
break
}
}
}
}
}
}
if asset.GatewayId == 0 || gateway == nil {
return
}
g, err := OpenTunnel(isConnectable, sessionId, ip, port, gateway)
if err != nil {
return
}
ip, port = g.LocalIp, g.LocalPort
return
}