feat: guacd-gateway

This commit is contained in:
ttk
2024-02-27 18:21:57 +08:00
parent c3569a9fad
commit dfc93433ff
5 changed files with 56 additions and 23 deletions

View File

@@ -56,3 +56,8 @@ sshServer:
account: test account: test
password: 135790 password: 135790
xtoken: 123456 xtoken: 123456
guacd:
ip: oneterm-guacd
port: 4822
gateway: oneterm-api

View File

@@ -133,6 +133,7 @@ type SshServer struct {
type Guacd struct { type Guacd struct {
Ip string `yaml:"ip"` Ip string `yaml:"ip"`
Port int `yaml:"port"` Port int `yaml:"port"`
Gateway string `yaml:"gateway"`
} }
type ConfigYaml struct { type ConfigYaml struct {

View File

@@ -31,6 +31,7 @@ import (
"github.com/veops/oneterm/pkg/server/guacd" "github.com/veops/oneterm/pkg/server/guacd"
"github.com/veops/oneterm/pkg/server/model" "github.com/veops/oneterm/pkg/server/model"
"github.com/veops/oneterm/pkg/server/storage/db/mysql" "github.com/veops/oneterm/pkg/server/storage/db/mysql"
"github.com/veops/oneterm/pkg/util"
) )
var ( var (
@@ -410,7 +411,7 @@ func newSshReq(ctx *gin.Context, action int) *model.SshReq {
func connectGuacd(ctx *gin.Context, protocol string, chs *model.SessionChans) { func connectGuacd(ctx *gin.Context, protocol string, chs *model.SessionChans) {
w, h, dpi := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h")), cast.ToInt(ctx.Query("dpi")) w, h, dpi := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h")), cast.ToInt(ctx.Query("dpi"))
w, h, dpi = 731, 929, 96 //TODO w, h, dpi = 746, 929, 96 //TODO
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
var err error var err error
@@ -419,7 +420,7 @@ func connectGuacd(ctx *gin.Context, protocol string, chs *model.SessionChans) {
}() }()
asset, account, gateway := &model.Asset{}, &model.Account{}, &model.Gateway{} asset, account, gateway := &model.Asset{}, &model.Account{}, &model.Gateway{}
if err := mysql.DB.Model(&asset).Where("id = ?", ctx.Param("asset_id")).First(asset).Error; err != nil { if err = mysql.DB.Model(&asset).Where("id = ?", ctx.Param("asset_id")).First(asset).Error; err != nil {
logger.L.Error("find asset failed", zap.Error(err)) logger.L.Error("find asset failed", zap.Error(err))
return return
} }
@@ -428,15 +429,18 @@ func connectGuacd(ctx *gin.Context, protocol string, chs *model.SessionChans) {
logger.L.Error(err.Error()) logger.L.Error(err.Error())
return return
} }
if err := mysql.DB.Model(&account).Where("id = ?", ctx.Param("account_id")).First(account).Error; err != nil { if err = mysql.DB.Model(&account).Where("id = ?", ctx.Param("account_id")).First(account).Error; err != nil {
logger.L.Error("find account failed", zap.Error(err)) logger.L.Error("find account failed", zap.Error(err))
return return
} }
if asset.GatewayId != 0 { if asset.GatewayId != 0 {
if err := mysql.DB.Model(&gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil { if err = mysql.DB.Model(&gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
logger.L.Error("find gateway failed", zap.Error(err)) logger.L.Error("find gateway failed", zap.Error(err))
return return
} }
gateway.Password = util.DecryptAES(gateway.Password)
gateway.Pk = util.DecryptAES(gateway.Pk)
gateway.Phrase = util.DecryptAES(gateway.Phrase)
} }
t, err := guacd.NewTunnel("", w, h, dpi, protocol, asset, account, gateway) t, err := guacd.NewTunnel("", w, h, dpi, protocol, asset, account, gateway)
@@ -850,7 +854,8 @@ func (c *Controller) TestConnect(ctx *gin.Context) {
NickName: "", NickName: "",
}, },
}) })
ctx.Params = append(ctx.Params, gin.Param{Key: "asset_id", Value: "1"}, gin.Param{Key: "account_id", Value: "1"}, gin.Param{Key: "protocol", Value: "rdp:13389"}) // ctx.Params = append(ctx.Params, gin.Param{Key: "asset_id", Value: "1"}, gin.Param{Key: "account_id", Value: "1"}, gin.Param{Key: "protocol", Value: "rdp:13389"})
ctx.Params = append(ctx.Params, gin.Param{Key: "asset_id", Value: "1"}, gin.Param{Key: "account_id", Value: "3"}, gin.Param{Key: "protocol", Value: "vnc:15901"})
c.Connect(ctx) c.Connect(ctx)
} }
@@ -902,7 +907,7 @@ func handleError(ctx *gin.Context, sessionId string, err error, ws *websocket.Co
if err == nil { if err == nil {
return return
} }
logger.L.Debug("monitor failed", zap.String("session_id", sessionId), zap.Error(err)) logger.L.Debug("", zap.String("session_id", sessionId), zap.Error(err))
ae, ok := err.(*ApiError) ae, ok := err.(*ApiError)
if !ok { if !ok {
return return

View File

@@ -16,13 +16,13 @@ import (
) )
const ( const (
recordingPath = "/playback" recordingPath = "/replay"
createRecording = "true" createRecording = "true"
ignoreCert = "true" ignoreCert = "true"
) )
var ( var (
gatewayManager = &model.GateWayManager{} gatewayManager = model.NewGateWayManager()
) )
type Configuration struct { type Configuration struct {
@@ -88,13 +88,14 @@ func NewTunnel(connectionId string, w, h, dpi int, protocol string, asset *model
} }
if t.ConnectionId == "" { if t.ConnectionId == "" {
t.SessionId = uuid.New().String() t.SessionId = uuid.New().String()
t.Config.Parameters["recording-name"] = t.SessionId
} }
if gateway != nil && connectionId == "" { if gateway != nil && gateway.Id != 0 && connectionId == "" {
t.g, err = gatewayManager.Open(t.SessionId, asset.Ip, cast.ToInt(port), gateway) t.g, err = gatewayManager.Open(t.SessionId, asset.Ip, cast.ToInt(port), gateway)
if err != nil { if err != nil {
return t, err return t, err
} }
t.Config.Parameters["hostname"] = t.g.LocalIp t.Config.Parameters["hostname"] = conf.Cfg.Guacd.Gateway
t.Config.Parameters["port"] = cast.ToString(t.g.LocalPort) t.Config.Parameters["port"] = cast.ToString(t.g.LocalPort)
} }
@@ -207,5 +208,5 @@ func (t *Tunnel) assert(opcode string) (instruction *Instruction, err error) {
} }
func (t *Tunnel) Close() { func (t *Tunnel) Close() {
t.g.Close() gatewayManager.Close(t.g.Id, t.SessionId)
} }

View File

@@ -7,6 +7,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/veops/oneterm/pkg/conf"
"github.com/veops/oneterm/pkg/logger"
"go.uber.org/zap"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"gorm.io/plugin/soft_delete" "gorm.io/plugin/soft_delete"
) )
@@ -63,18 +66,21 @@ type GatewayCount struct {
} }
type GatewayTunnel struct { type GatewayTunnel struct {
Id int
LocalIp string LocalIp string
LocalPort int LocalPort int
listener net.Listener listener net.Listener
localConnections map[string]net.Conn localConnections map[string]net.Conn
remoteConnections map[string]net.Conn remoteConnections map[string]net.Conn
sshClient *ssh.Client sshClient *ssh.Client
using bool
} }
func (gt *GatewayTunnel) Open(sessionId, remoteIp string, remotePort int) error { func (gt *GatewayTunnel) Open(sessionId, remoteIp string, remotePort int) error {
for { for {
lc, err := gt.listener.Accept() lc, err := gt.listener.Accept()
if err != nil { if err != nil {
logger.L.Error("accept failed", zap.Error(err))
return err return err
} }
gt.localConnections[sessionId] = lc gt.localConnections[sessionId] = lc
@@ -82,6 +88,7 @@ func (gt *GatewayTunnel) Open(sessionId, remoteIp string, remotePort int) error
remoteAddr := fmt.Sprintf("%s:%d", remoteIp, remotePort) remoteAddr := fmt.Sprintf("%s:%d", remoteIp, remotePort)
rc, err := gt.sshClient.Dial("tcp", remoteAddr) rc, err := gt.sshClient.Dial("tcp", remoteAddr)
if err != nil { if err != nil {
logger.L.Error("dial remote failed", zap.Error(err))
return err return err
} }
gt.remoteConnections[sessionId] = rc gt.remoteConnections[sessionId] = rc
@@ -95,10 +102,14 @@ func (gt *GatewayTunnel) Close(sessionId string) {
if c, ok := gt.remoteConnections[sessionId]; ok { if c, ok := gt.remoteConnections[sessionId]; ok {
c.Close() c.Close()
} }
delete(gt.localConnections, sessionId)
if c, ok := gt.localConnections[sessionId]; ok { if c, ok := gt.localConnections[sessionId]; ok {
c.Close() c.Close()
} }
delete(gt.remoteConnections, sessionId)
gt.using = len(gt.localConnections) > 0 && len(gt.remoteConnections) > 0
} }
type GateWayManager struct { type GateWayManager struct {
@@ -106,6 +117,13 @@ type GateWayManager struct {
mtx sync.Mutex mtx sync.Mutex
} }
func NewGateWayManager() *GateWayManager {
return &GateWayManager{
gateways: map[int]*GatewayTunnel{},
mtx: sync.Mutex{},
}
}
func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gateway *Gateway) (g *GatewayTunnel, err error) { func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gateway *Gateway) (g *GatewayTunnel, err error) {
gm.mtx.Lock() gm.mtx.Lock()
defer gm.mtx.Unlock() defer gm.mtx.Unlock()
@@ -134,32 +152,37 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew
if err != nil { if err != nil {
return return
} }
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localPort)) listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", conf.Cfg.Guacd.Gateway, localPort))
if err != nil { if err != nil {
return return
} }
g = &GatewayTunnel{ g = &GatewayTunnel{
LocalIp: "127.0.0.1", Id: gateway.Id,
LocalIp: conf.Cfg.Guacd.Gateway,
LocalPort: localPort, LocalPort: localPort,
listener: listener, listener: listener,
localConnections: map[string]net.Conn{}, localConnections: map[string]net.Conn{},
remoteConnections: map[string]net.Conn{}, remoteConnections: map[string]net.Conn{},
sshClient: sshClient, sshClient: sshClient,
using: true,
} }
err = g.Open(sessionId, remoteIp, remotePort) go g.Open(sessionId, remoteIp, remotePort)
return return
} }
func (gm *GateWayManager) Close(id int) { func (gm *GateWayManager) Close(id int, sessionId string) {
gm.mtx.Lock() gm.mtx.Lock()
defer gm.mtx.Unlock() defer gm.mtx.Unlock()
g, ok := gm.gateways[id] g, ok := gm.gateways[id]
if ok { if ok {
g.Close() g.Close(sessionId)
} }
if !g.using {
defer g.sshClient.Close()
delete(gm.gateways, id) delete(gm.gateways, id)
}
} }
func (gm *GateWayManager) getAuth(gateway *Gateway) (ssh.AuthMethod, error) { func (gm *GateWayManager) getAuth(gateway *Gateway) (ssh.AuthMethod, error) {
@@ -195,9 +218,7 @@ func getAvailablePort() (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer l.Close()
defer func(l *net.TCPListener) {
_ = l.Close()
}(l)
return l.Addr().(*net.TCPAddr).Port, nil return l.Addr().(*net.TCPAddr).Port, nil
} }