mirror of
https://github.com/veops/oneterm.git
synced 2025-10-08 00:30:12 +08:00
212 lines
5.1 KiB
Go
212 lines
5.1 KiB
Go
package gateway
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
"github.com/veops/oneterm/logger"
|
|
"github.com/veops/oneterm/model"
|
|
)
|
|
|
|
var (
|
|
manager = &GateWayManager{
|
|
gatewayTunnels: map[string]*GatewayTunnel{},
|
|
sshClients: map[int]*ssh.Client{},
|
|
sshClientsCount: map[int]int{},
|
|
mtx: sync.Mutex{},
|
|
}
|
|
)
|
|
|
|
func GetGatewayManager() *GateWayManager {
|
|
return manager
|
|
}
|
|
|
|
func GetGatewayTunnelBySessionId(sessionId string) *GatewayTunnel {
|
|
return manager.gatewayTunnels[sessionId]
|
|
}
|
|
|
|
// GatewayTunnel describes one local-listener-to-remote-address tunnel that is
// carried over a gateway's SSH client.
type GatewayTunnel struct {
	// listener is the local listener from which Open accepts its connection.
	listener net.Listener
	// GatewayId is the key of the SSH client in the manager's client map.
	GatewayId int
	// SessionId is the key under which the manager registers this tunnel.
	SessionId string
	// LocalIp and LocalPort are the address the local listener was bound to.
	LocalIp   string
	LocalPort int
	// RemoteIp and RemotePort are the target dialed through the gateway.
	RemoteIp   string
	RemotePort int
	// LocalConn is the connection accepted from listener; nil until accepted.
	LocalConn net.Conn
	// RemoteConn is the connection dialed to the remote target; nil until dialed.
	RemoteConn net.Conn
	// Opened receives nil when Open starts and Open's final error when it returns.
	Opened chan error
}
|
|
|
|
func (gt *GatewayTunnel) Open(isConnectable bool) (err error) {
|
|
go func() {
|
|
<-time.After(time.Second * 3)
|
|
logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId))
|
|
gt.listener.Close()
|
|
}()
|
|
defer func() {
|
|
logger.L().Debug("close listener", zap.String("sessionId", gt.SessionId), zap.Error(err))
|
|
gt.Opened <- err
|
|
}()
|
|
gt.Opened <- nil
|
|
gt.LocalConn, err = gt.listener.Accept()
|
|
if err != nil {
|
|
logger.L().Error("accept failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
|
|
return
|
|
}
|
|
remoteAddr := fmt.Sprintf("%s:%d", gt.RemoteIp, gt.RemotePort)
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
gt.RemoteConn, err = manager.sshClients[gt.GatewayId].DialContext(ctx, "tcp", remoteAddr)
|
|
if err != nil {
|
|
defer func() {
|
|
if gt.LocalConn != nil {
|
|
defer gt.LocalConn.Close()
|
|
}
|
|
if gt.RemoteConn != nil {
|
|
defer gt.RemoteConn.Close()
|
|
}
|
|
}()
|
|
logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
|
|
return
|
|
}
|
|
if isConnectable {
|
|
return
|
|
}
|
|
go io.Copy(gt.LocalConn, gt.RemoteConn)
|
|
go io.Copy(gt.RemoteConn, gt.LocalConn)
|
|
|
|
return
|
|
}
|
|
|
|
type GateWayManager struct {
|
|
gatewayTunnels map[string]*GatewayTunnel
|
|
sshClients map[int]*ssh.Client
|
|
sshClientsCount map[int]int
|
|
mtx sync.Mutex
|
|
}
|
|
|
|
func (gm *GateWayManager) Open(isConnectable bool, sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) {
|
|
if gateway == nil {
|
|
err = fmt.Errorf("gateway is nil")
|
|
return
|
|
}
|
|
gm.mtx.Lock()
|
|
defer gm.mtx.Unlock()
|
|
|
|
sshCli, ok := gm.sshClients[gateway.Id]
|
|
if !ok {
|
|
var auth ssh.AuthMethod
|
|
auth, err = gm.getAuth(gateway)
|
|
if err != nil {
|
|
return
|
|
}
|
|
sshCli, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", gateway.Host, gateway.Port), &ssh.ClientConfig{
|
|
User: gateway.Account,
|
|
Auth: []ssh.AuthMethod{auth},
|
|
Timeout: time.Second,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
})
|
|
if err != nil {
|
|
logger.L().Error("open gateway sshcli failed", zap.Int("gatewayId", gateway.Id), zap.Error(err))
|
|
return
|
|
}
|
|
go func() {
|
|
logger.L().Debug("ssh proxy wait closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait()))
|
|
delete(gm.sshClients, gateway.Id)
|
|
}()
|
|
}
|
|
gm.sshClients[gateway.Id] = sshCli
|
|
gm.sshClientsCount[gateway.Id] += 1
|
|
localPort, err := getAvailablePort()
|
|
if err != nil {
|
|
return
|
|
}
|
|
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "localhost", localPort))
|
|
if err != nil {
|
|
return
|
|
}
|
|
g = &GatewayTunnel{
|
|
listener: listener,
|
|
GatewayId: gateway.Id,
|
|
SessionId: sessionId,
|
|
LocalIp: "localhost",
|
|
LocalPort: localPort,
|
|
RemoteIp: remoteIp,
|
|
RemotePort: remotePort,
|
|
Opened: make(chan error),
|
|
}
|
|
gm.gatewayTunnels[sessionId] = g
|
|
go g.Open(isConnectable)
|
|
|
|
logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId))
|
|
<-g.Opened
|
|
logger.L().Debug("opened gateway", zap.Any("sessionId", sessionId))
|
|
|
|
return
|
|
}
|
|
|
|
func (gm *GateWayManager) Close(sessionIds ...string) {
|
|
gm.mtx.Lock()
|
|
defer gm.mtx.Unlock()
|
|
for _, sid := range sessionIds {
|
|
gt, ok := gm.gatewayTunnels[sid]
|
|
if !ok {
|
|
return
|
|
}
|
|
gm.sshClientsCount[gt.GatewayId] -= 1
|
|
if gm.sshClientsCount[gt.GatewayId] <= 0 {
|
|
if g := gm.sshClients[gt.GatewayId]; g != nil {
|
|
g.Close()
|
|
}
|
|
delete(gm.sshClients, gt.GatewayId)
|
|
delete(gm.sshClientsCount, gt.GatewayId)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (gm *GateWayManager) getAuth(gateway *model.Gateway) (ssh.AuthMethod, error) {
|
|
switch gateway.AccountType {
|
|
case model.AUTHMETHOD_PASSWORD:
|
|
return ssh.Password(gateway.Password), nil
|
|
case model.AUTHMETHOD_PUBLICKEY:
|
|
if gateway.Phrase == "" {
|
|
pk, err := ssh.ParsePrivateKey([]byte(gateway.Pk))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ssh.PublicKeys(pk), nil
|
|
} else {
|
|
pk, err := ssh.ParsePrivateKeyWithPassphrase([]byte(gateway.Pk), []byte(gateway.Phrase))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ssh.PublicKeys(pk), nil
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("invalid authmethod %d", gateway.AccountType)
|
|
}
|
|
}
|
|
|
|
// getAvailablePort asks the operating system for a free TCP port on localhost
// by binding port 0, reading back the assigned port, and closing the listener.
//
// The port is released before it is returned, so another process may bind it
// before the caller does.
func getAvailablePort() (int, error) {
	loopback, resolveErr := net.ResolveTCPAddr("tcp", "localhost:0")
	if resolveErr != nil {
		return 0, resolveErr
	}

	probe, listenErr := net.ListenTCP("tcp", loopback)
	if listenErr != nil {
		return 0, listenErr
	}
	port := probe.Addr().(*net.TCPAddr).Port
	probe.Close()

	return port, nil
}
|