Files
oneterm/backend/gateway/gateway.go
2024-10-11 18:14:37 +08:00

212 lines
5.1 KiB
Go

package gateway
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"github.com/veops/oneterm/logger"
"github.com/veops/oneterm/model"
)
// manager is the single package-wide gateway manager. Its maps start empty and
// its mutex starts unlocked (the zero value of sync.Mutex).
var manager = &GateWayManager{
	gatewayTunnels:  make(map[string]*GatewayTunnel),
	sshClients:      make(map[int]*ssh.Client),
	sshClientsCount: make(map[int]int),
}
// GetGatewayManager returns the package-level gateway manager shared by all
// callers.
func GetGatewayManager() (m *GateWayManager) {
	m = manager
	return
}
// GetGatewayTunnelBySessionId returns the tunnel registered for sessionId, or
// nil if no tunnel is registered under that id.
//
// The lookup is done under the manager's mutex because GateWayManager.Open and
// GateWayManager.Close modify the tunnel map while holding it; reading the map
// without the lock is a data race.
func GetGatewayTunnelBySessionId(sessionId string) *GatewayTunnel {
	manager.mtx.Lock()
	defer manager.mtx.Unlock()
	return manager.gatewayTunnels[sessionId]
}
// GatewayTunnel is a single local-port forward through a gateway's SSH client.
// A client connects to LocalIp:LocalPort and its traffic is relayed to
// RemoteIp:RemotePort via the SSH connection of gateway GatewayId.
type GatewayTunnel struct {
	// listener accepts the local client connection; it is created by
	// GateWayManager.Open and closed by a timer in Open after 3 seconds.
	listener net.Listener
	// GatewayId selects the shared SSH client in the manager's sshClients map.
	GatewayId int
	// SessionId is the key under which the tunnel is registered in the manager.
	SessionId string
	// LocalIp/LocalPort are the address the listener is bound to.
	LocalIp   string
	LocalPort int
	// RemoteIp/RemotePort are the target dialed through the gateway.
	RemoteIp   string
	RemotePort int
	// LocalConn is the connection accepted from the listener.
	LocalConn net.Conn
	// RemoteConn is the connection dialed to the target over SSH.
	RemoteConn net.Conn
	// Opened is an unbuffered channel: Open first sends nil once it is about
	// to accept, and sends its final result (nil or an error) when it returns.
	Opened chan error
}
// Open runs the forwarding side of the tunnel; GateWayManager.Open starts it
// as a goroutine.
//
// It first sends nil on gt.Opened to signal that it is about to accept, then
// accepts one local connection and dials gt.RemoteIp:gt.RemotePort through the
// gateway's SSH client (1 second dial timeout). When isConnectable is true no
// data is copied and both connections are left to the caller. Otherwise bytes
// are copied in both directions; as soon as either direction finishes, both
// connections are closed so the peer observes EOF and the opposite copy
// goroutine exits instead of leaking.
//
// The final result is sent on gt.Opened before returning. The channel is
// unbuffered, so this goroutine blocks until someone receives that value.
// Independently, the listener is closed 3 seconds after Open starts.
func (gt *GatewayTunnel) Open(isConnectable bool) (err error) {
	go func() {
		<-time.After(time.Second * 3)
		logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId))
		gt.listener.Close()
	}()
	defer func() {
		logger.L().Debug("close listener", zap.String("sessionId", gt.SessionId), zap.Error(err))
		gt.Opened <- err
	}()
	gt.Opened <- nil

	gt.LocalConn, err = gt.listener.Accept()
	if err != nil {
		logger.L().Error("accept failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
		return
	}

	// Read the client under the lock: the map is mutated concurrently, and the
	// entry may already be gone (gateway closed, or its SSH connection died).
	// Calling DialContext on a nil *ssh.Client would panic. Taking the lock is
	// safe here because GateWayManager.Open releases it once it has received
	// the first value from gt.Opened above.
	manager.mtx.Lock()
	sshCli := manager.sshClients[gt.GatewayId]
	manager.mtx.Unlock()
	if sshCli == nil {
		err = fmt.Errorf("ssh client of gateway %d is not available", gt.GatewayId)
		logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
		gt.LocalConn.Close()
		return
	}

	remoteAddr := fmt.Sprintf("%s:%d", gt.RemoteIp, gt.RemotePort)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	gt.RemoteConn, err = sshCli.DialContext(ctx, "tcp", remoteAddr)
	if err != nil {
		logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
		gt.LocalConn.Close()
		if gt.RemoteConn != nil {
			gt.RemoteConn.Close()
		}
		return
	}
	if isConnectable {
		return
	}
	go gt.copyAndClose(gt.LocalConn, gt.RemoteConn)
	go gt.copyAndClose(gt.RemoteConn, gt.LocalConn)
	return
}

// copyAndClose copies src into dst until EOF or an error, then closes both
// connections so the opposite copy direction is unblocked as well. Closing an
// already-closed connection only returns an error, which is ignored.
func (gt *GatewayTunnel) copyAndClose(dst, src net.Conn) {
	defer dst.Close()
	defer src.Close()
	io.Copy(dst, src)
}
// GateWayManager tracks the local tunnels opened through SSH gateways and the
// SSH clients they share.
type GateWayManager struct {
	// gatewayTunnels maps a session id to the tunnel opened for it.
	gatewayTunnels map[string]*GatewayTunnel
	// sshClients holds one SSH client per gateway id, reused across tunnels.
	sshClients map[int]*ssh.Client
	// sshClientsCount counts the tunnels opened per gateway id; Close releases
	// the gateway's SSH client when this drops to zero.
	sshClientsCount map[int]int
	// mtx guards the maps above; Open and Close hold it for their whole body.
	mtx sync.Mutex
}
// Open creates a local TCP tunnel that forwards to remoteIp:remotePort through
// gateway, and registers it under sessionId.
//
// SSH clients are shared per gateway id: the first tunnel for a gateway dials
// it (1 second timeout, host key not verified), later tunnels reuse it, and
// sshClientsCount records how many registered tunnels hold it so Close knows
// when to release it. The tunnel listens on a kernel-assigned localhost port
// (reported in g.LocalIp/g.LocalPort). Its forwarding goroutine
// (GatewayTunnel.Open, receiving isConnectable) is started, and Open waits for
// that goroutine's first signal on g.Opened before returning.
//
// An error is returned when gateway is nil, its credentials cannot be turned
// into an SSH auth method, the SSH dial fails, or the local listener cannot be
// created. On error no tunnel is registered and no reference is counted.
func (gm *GateWayManager) Open(isConnectable bool, sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) {
	if gateway == nil {
		err = fmt.Errorf("gateway is nil")
		return
	}
	gm.mtx.Lock()
	defer gm.mtx.Unlock()

	sshCli, ok := gm.sshClients[gateway.Id]
	if !ok {
		var auth ssh.AuthMethod
		auth, err = gm.getAuth(gateway)
		if err != nil {
			return
		}
		sshCli, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", gateway.Host, gateway.Port), &ssh.ClientConfig{
			User:            gateway.Account,
			Auth:            []ssh.AuthMethod{auth},
			Timeout:         time.Second,
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		})
		if err != nil {
			logger.L().Error("open gateway sshcli failed", zap.Int("gatewayId", gateway.Id), zap.Error(err))
			return
		}
		gm.sshClients[gateway.Id] = sshCli
		go func(id int, cli *ssh.Client) {
			waitErr := cli.Wait()
			logger.L().Debug("ssh proxy wait closed", zap.Int("gatewayId", id), zap.Error(waitErr))
			gm.mtx.Lock()
			defer gm.mtx.Unlock()
			// Close may already have removed this client, and a newer client for
			// the same gateway may have replaced it; only drop our own entry.
			if gm.sshClients[id] == cli {
				delete(gm.sshClients, id)
			}
		}(gateway.Id, sshCli)
	}

	// Listen on port 0 so the kernel picks a free port atomically. Probing for
	// a free port first and listening on it afterwards can race with another
	// process taking that port in between.
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		// Don't keep an SSH client that no registered tunnel references.
		if gm.sshClientsCount[gateway.Id] <= 0 {
			sshCli.Close()
			delete(gm.sshClients, gateway.Id)
			delete(gm.sshClientsCount, gateway.Id)
		}
		return
	}
	// Count the reference only once the tunnel will actually be registered;
	// a count taken for a failed Open would never be released by Close.
	gm.sshClientsCount[gateway.Id] += 1

	g = &GatewayTunnel{
		listener:   listener,
		GatewayId:  gateway.Id,
		SessionId:  sessionId,
		LocalIp:    "localhost",
		LocalPort:  listener.Addr().(*net.TCPAddr).Port,
		RemoteIp:   remoteIp,
		RemotePort: remotePort,
		Opened:     make(chan error),
	}
	gm.gatewayTunnels[sessionId] = g
	go g.Open(isConnectable)
	logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId))
	<-g.Opened
	logger.L().Debug("opened gateway", zap.Any("sessionId", sessionId))
	return
}
// Close tears down the tunnels registered for the given session ids.
//
// Each known session's tunnel is unregistered and its listener closed, and the
// reference count of its gateway's SSH client is decremented. Once no tunnel
// references the client any more, it is closed and forgotten. Unknown session
// ids, including ones that were already closed, are skipped. Processing
// continues with the remaining ids, and closing the same session twice does
// not release the shared SSH client twice.
func (gm *GateWayManager) Close(sessionIds ...string) {
	gm.mtx.Lock()
	defer gm.mtx.Unlock()
	for _, sid := range sessionIds {
		gt, ok := gm.gatewayTunnels[sid]
		if !ok {
			continue
		}
		delete(gm.gatewayTunnels, sid)
		// Stops a still-pending Accept. The listener may be closed again later
		// by the tunnel's own timer; that second Close just returns an error.
		if gt.listener != nil {
			gt.listener.Close()
		}

		gm.sshClientsCount[gt.GatewayId] -= 1
		if gm.sshClientsCount[gt.GatewayId] <= 0 {
			if g := gm.sshClients[gt.GatewayId]; g != nil {
				g.Close()
			}
			delete(gm.sshClients, gt.GatewayId)
			delete(gm.sshClientsCount, gt.GatewayId)
		}
	}
}
// getAuth builds the SSH auth method for gateway from its stored credentials.
//
// Password accounts authenticate with gateway.Password. Public-key accounts
// parse the private key in gateway.Pk, decrypting it with gateway.Phrase when
// a phrase is set. Any other account type, or a key that fails to parse,
// yields an error.
func (gm *GateWayManager) getAuth(gateway *model.Gateway) (ssh.AuthMethod, error) {
	if gateway.AccountType == model.AUTHMETHOD_PASSWORD {
		return ssh.Password(gateway.Password), nil
	}
	if gateway.AccountType != model.AUTHMETHOD_PUBLICKEY {
		return nil, fmt.Errorf("invalid authmethod %d", gateway.AccountType)
	}

	var (
		signer ssh.Signer
		err    error
	)
	keyBytes := []byte(gateway.Pk)
	if gateway.Phrase != "" {
		signer, err = ssh.ParsePrivateKeyWithPassphrase(keyBytes, []byte(gateway.Phrase))
	} else {
		signer, err = ssh.ParsePrivateKey(keyBytes)
	}
	if err != nil {
		return nil, err
	}
	return ssh.PublicKeys(signer), nil
}
// getAvailablePort asks the OS for a currently unused TCP port on localhost.
//
// It binds a throwaway listener to port 0, reads back the port the kernel
// assigned, and closes the listener before returning. The port was free at
// that moment, but nothing reserves it for the caller afterwards.
func getAvailablePort() (int, error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}
	ln, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return 0, err
	}
	port := ln.Addr().(*net.TCPAddr).Port
	ln.Close()
	return port, nil
}