mirror of
https://github.com/veops/oneterm.git
synced 2025-09-27 11:42:08 +08:00
413 lines
12 KiB
Go
413 lines
12 KiB
Go
package connector
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gliderlabs/ssh"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/samber/lo"
|
|
"github.com/spf13/cast"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/sync/errgroup"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/veops/oneterm/internal/acl"
|
|
"github.com/veops/oneterm/internal/connector/protocols"
|
|
"github.com/veops/oneterm/internal/connector/protocols/db"
|
|
"github.com/veops/oneterm/internal/model"
|
|
"github.com/veops/oneterm/internal/repository"
|
|
"github.com/veops/oneterm/internal/service"
|
|
fileservice "github.com/veops/oneterm/internal/service/file"
|
|
gsession "github.com/veops/oneterm/internal/session"
|
|
myErrors "github.com/veops/oneterm/pkg/errors"
|
|
"github.com/veops/oneterm/pkg/logger"
|
|
)
|
|
|
|
// byteClearAll is the two-byte sequence Ctrl-U (0x15) followed by a carriage
// return. HandleTerm feeds it to both the command parser and the remote input
// stream right after a forbidden command has been detected.
var byteClearAll = []byte{0x15, '\r'}
|
|
|
|
func Connect(ctx *gin.Context) {
|
|
ctx.Set("sessionType", model.SESSIONTYPE_WEB)
|
|
|
|
ws, err := protocols.Upgrader.Upgrade(ctx.Writer, ctx.Request, http.Header{
|
|
"sec-websocket-protocol": {ctx.GetHeader("sec-websocket-protocol")},
|
|
})
|
|
if err != nil {
|
|
ctx.AbortWithError(http.StatusInternalServerError, err)
|
|
return
|
|
}
|
|
defer ws.Close()
|
|
|
|
if shareErr, exists := ctx.Get("shareErr"); exists && shareErr != nil {
|
|
if apiErr, ok := shareErr.(*myErrors.ApiError); ok {
|
|
errMsg := apiErr.MessageWithCtx(ctx)
|
|
ws.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("\r\n \033[31m Invalid share link: %s \x1b[0m", errMsg)))
|
|
return
|
|
}
|
|
}
|
|
|
|
var sess *gsession.Session
|
|
defer func() {
|
|
protocols.HandleError(ctx, sess, err, ws, nil)
|
|
}()
|
|
|
|
sess, err = DoConnect(ctx, ws)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if sess.IsGuacd() {
|
|
protocols.HandleGuacd(sess)
|
|
} else {
|
|
HandleTerm(sess, ctx)
|
|
}
|
|
}
|
|
|
|
func ConnectMonitor(ctx *gin.Context) {
|
|
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
|
|
|
sessionId := ctx.Param("session_id")
|
|
var sess *gsession.Session
|
|
ws, err := protocols.Upgrader.Upgrade(ctx.Writer, ctx.Request, http.Header{
|
|
"sec-websocket-protocol": {ctx.GetHeader("sec-websocket-protocol")},
|
|
})
|
|
if err != nil {
|
|
ctx.AbortWithError(http.StatusInternalServerError, err)
|
|
return
|
|
}
|
|
defer ws.Close()
|
|
|
|
chs := gsession.NewSessionChans()
|
|
defer func() {
|
|
protocols.HandleError(ctx, sess, err, ws, chs)
|
|
}()
|
|
|
|
if !acl.IsAdmin(currentUser) {
|
|
ctx.AbortWithError(http.StatusBadRequest, &myErrors.ApiError{Code: myErrors.ErrNoPerm, Data: map[string]any{"perm": "monitor session"}})
|
|
return
|
|
}
|
|
|
|
if sess = gsession.GetOnlineSessionById(sessionId); sess == nil {
|
|
err = &myErrors.ApiError{Code: myErrors.ErrInvalidSessionId, Data: map[string]any{"sessionId": sessionId}}
|
|
return
|
|
}
|
|
|
|
g, gctx := errgroup.WithContext(ctx)
|
|
if sess.IsGuacd() {
|
|
g.Go(func() error {
|
|
return protocols.MonitGuacd(ctx, sess, chs, ws)
|
|
})
|
|
}
|
|
|
|
key := fmt.Sprintf("%d-%s-%d", currentUser.GetUid(), sessionId, time.Now().Nanosecond())
|
|
sess.Monitors.Store(key, ws)
|
|
defer sess.Monitors.Delete(key)
|
|
|
|
g.Go(func() error {
|
|
for {
|
|
select {
|
|
case <-gctx.Done():
|
|
return nil
|
|
default:
|
|
_, p, err := ws.ReadMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if sess.IsGuacd() {
|
|
chs.InChan <- p
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
g.Wait()
|
|
logger.L().Info("monitor exit", zap.String("sessionId", sess.SessionId))
|
|
}
|
|
|
|
func ConnectClose(ctx *gin.Context) {
|
|
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
|
if !acl.IsAdmin(currentUser) {
|
|
ctx.AbortWithError(http.StatusBadRequest, &myErrors.ApiError{Code: myErrors.ErrNoPerm, Data: map[string]any{"perm": "close session"}})
|
|
return
|
|
}
|
|
|
|
sessionService := service.NewSessionService()
|
|
session, err := sessionService.GetOnlineSessionByID(ctx, ctx.Param("session_id"))
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
ctx.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
|
return
|
|
}
|
|
if err != nil {
|
|
ctx.AbortWithError(http.StatusBadRequest, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": "invalid session id"}})
|
|
return
|
|
}
|
|
|
|
logger.L().Info("closing...", zap.String("sessionId", session.SessionId), zap.Int("type", session.SessionType))
|
|
defer protocols.OfflineSession(ctx, session.SessionId, currentUser.GetUserName())
|
|
|
|
session.Status = model.SESSIONSTATUS_OFFLINE
|
|
session.ClosedAt = lo.ToPtr(time.Now())
|
|
gsession.UpsertSession(session)
|
|
|
|
ctx.JSON(http.StatusOK, gin.H{"code": 0, "message": "ok"})
|
|
}
|
|
|
|
// DoConnect handles the connection setup process
|
|
func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, err error) {
|
|
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
|
|
|
assetId, accountId := cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))
|
|
asset, account, gateway, err := repository.GetAAG(assetId, accountId)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
sessionId := ctx.Query("session_id")
|
|
if sessionId == "" {
|
|
sessionId = uuid.New().String()
|
|
}
|
|
|
|
sess = gsession.NewSession(ctx)
|
|
sess.Ws = ws
|
|
sess.Session = &model.Session{
|
|
SessionType: ctx.GetInt("sessionType"),
|
|
SessionId: sessionId,
|
|
Uid: currentUser.GetUid(),
|
|
UserName: currentUser.GetUserName(),
|
|
AssetId: assetId,
|
|
Asset: asset,
|
|
AssetInfo: fmt.Sprintf("%s(%s)", asset.Name, asset.Ip),
|
|
AccountId: accountId,
|
|
AccountInfo: fmt.Sprintf("%s(%s)", account.Name, account.Account),
|
|
GatewayId: asset.GatewayId,
|
|
GatewayInfo: lo.Ternary(asset.GatewayId == 0, "", fmt.Sprintf("%s(%s)", gateway.Name, gateway.Host)),
|
|
Protocol: ctx.Param("protocol"),
|
|
Status: model.SESSIONSTATUS_ONLINE,
|
|
ShareId: cast.ToInt(ctx.Value("shareId")),
|
|
}
|
|
if sess.ShareId != 0 {
|
|
sess.ShareEnd, _ = ctx.Value("shareEnd").(time.Time)
|
|
if shareErr, exists := ctx.Get("shareErr"); exists && shareErr != nil {
|
|
if apiErr, ok := shareErr.(*myErrors.ApiError); ok {
|
|
err = apiErr
|
|
return
|
|
}
|
|
}
|
|
}
|
|
if !sess.IsGuacd() {
|
|
w, h := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h"))
|
|
sess.SshParser = gsession.NewParser(sess.SessionId, w, h)
|
|
sess.SshParser.Protocol = sess.Protocol
|
|
|
|
sessionService := service.NewSessionService()
|
|
cmds, err := sessionService.GetSshParserCommands(ctx, []int(asset.AccessAuth.CmdIds))
|
|
if err != nil {
|
|
return sess, err
|
|
}
|
|
sess.SshParser.Cmds = cmds
|
|
|
|
for _, c := range sess.SshParser.Cmds {
|
|
if c.IsRe {
|
|
c.Re, _ = regexp.Compile(c.Cmd)
|
|
}
|
|
}
|
|
if sess.SshRecoder, err = gsession.NewAsciinema(sess.SessionId, w, h); err != nil {
|
|
return sess, err
|
|
}
|
|
}
|
|
if sess.SessionType == model.SESSIONTYPE_WEB {
|
|
sess.ClientIp = ctx.ClientIP()
|
|
} else if sess.SessionType == model.SESSIONTYPE_CLIENT {
|
|
sess.ClientIp = ctx.RemoteIP()
|
|
}
|
|
|
|
if !protocols.CheckTime(asset.AccessAuth) {
|
|
err = &myErrors.ApiError{Code: myErrors.ErrAccessTime}
|
|
return
|
|
}
|
|
if ok, err := service.DefaultAuthService.HasAuthorization(ctx, sess); err != nil {
|
|
err = &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}}
|
|
return sess, err
|
|
} else if !ok {
|
|
err = &myErrors.ApiError{Code: myErrors.ErrUnauthorized, Data: map[string]any{"perm": "connect"}}
|
|
return sess, err
|
|
}
|
|
|
|
switch strings.Split(sess.Protocol, ":")[0] {
|
|
case "ssh":
|
|
go protocols.ConnectSsh(ctx, sess, asset, account, gateway)
|
|
case "redis", "mysql", "mongodb", "postgresql":
|
|
go db.ConnectDB(sess, asset, account, gateway)
|
|
case "telnet":
|
|
go protocols.ConnectTelnet(ctx, sess, asset, account, gateway)
|
|
case "vnc", "rdp":
|
|
go protocols.ConnectGuacd(ctx, sess, asset, account, gateway)
|
|
default:
|
|
logger.L().Error("wrong protocol " + sess.Protocol)
|
|
}
|
|
|
|
if err = <-sess.Chans.ErrChan; err != nil {
|
|
logger.L().Error("failed to connect", zap.Error(err))
|
|
err = &myErrors.ApiError{Code: myErrors.ErrConnectServer, Data: map[string]any{"err": err}}
|
|
return
|
|
}
|
|
|
|
gsession.GetOnlineSession().Store(sess.SessionId, sess)
|
|
gsession.UpsertSession(sess)
|
|
|
|
// Initialize session-based file client for high-performance file operations
|
|
// Only for SSH-based protocols that support SFTP
|
|
protocol := strings.Split(sess.Protocol, ":")[0]
|
|
if protocol == "ssh" {
|
|
if err := fileservice.DefaultFileService.InitSessionFileClient(sess.SessionId, sess.AssetId, sess.AccountId); err != nil {
|
|
logger.L().Warn("Failed to initialize session file client",
|
|
zap.String("sessionId", sess.SessionId),
|
|
zap.Int("assetId", sess.AssetId),
|
|
zap.Int("accountId", sess.AccountId),
|
|
zap.Error(err))
|
|
// Don't fail the session creation for file service initialization failure
|
|
} else {
|
|
logger.L().Info("Session file client initialized successfully",
|
|
zap.String("sessionId", sess.SessionId),
|
|
zap.Int("assetId", sess.AssetId),
|
|
zap.Int("accountId", sess.AccountId))
|
|
}
|
|
} else if protocol == "rdp" || protocol == "vnc" {
|
|
logger.L().Debug("Skipping session file client initialization for Guacamole protocol",
|
|
zap.String("protocol", protocol),
|
|
zap.String("sessionId", sess.SessionId))
|
|
// RDP and VNC use Guacamole protocol for file transfer, not SSH/SFTP
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// HandleTerm handles terminal sessions
|
|
func HandleTerm(sess *gsession.Session, ctx *gin.Context) (err error) {
|
|
defer func() {
|
|
logger.L().Debug("defer HandleTerm", zap.String("sessionId", sess.SessionId))
|
|
|
|
// Clean up session-based file client (only for SSH-based protocols)
|
|
protocol := strings.Split(sess.Protocol, ":")[0]
|
|
if protocol == "ssh" {
|
|
fileservice.DefaultFileService.CloseSessionFileClient(sess.SessionId)
|
|
// Clear SSH client from session to ensure proper cleanup
|
|
sess.ClearSSHClient()
|
|
}
|
|
|
|
// Close SSH recorder to save recording file
|
|
if sess.SshRecoder != nil {
|
|
if closeErr := sess.SshRecoder.Close(); closeErr != nil {
|
|
logger.L().Error("Failed to close SSH recorder", zap.String("sessionId", sess.SessionId), zap.Error(closeErr))
|
|
}
|
|
}
|
|
|
|
sess.SshParser.Close(sess.Prompt)
|
|
sess.Status = model.SESSIONSTATUS_OFFLINE
|
|
sess.ClosedAt = lo.ToPtr(time.Now())
|
|
if err = gsession.UpsertSession(sess); err != nil {
|
|
logger.L().Error("upsert session failed", zap.Error(err))
|
|
}
|
|
}()
|
|
chs := sess.Chans
|
|
tk, tk1s, tk1m := time.NewTicker(time.Millisecond*100), time.NewTicker(time.Second), time.NewTicker(time.Minute)
|
|
assetService := service.NewAssetService()
|
|
sess.G.Go(func() error {
|
|
return protocols.Read(sess)
|
|
})
|
|
sess.G.Go(func() (err error) {
|
|
defer sess.Chans.Rin.Close()
|
|
defer sess.Chans.Wout.Close()
|
|
for {
|
|
select {
|
|
case <-sess.Gctx.Done():
|
|
protocols.Write(sess)
|
|
return
|
|
case <-chs.AwayChan:
|
|
return
|
|
case <-sess.IdleTk.C:
|
|
msg := (&myErrors.ApiError{Code: myErrors.ErrIdleTimeout, Data: map[string]any{"second": model.GlobalConfig.Load().Timeout}}).MessageWithCtx(ctx)
|
|
protocols.WriteErrMsg(sess, msg)
|
|
return &myErrors.ApiError{Code: myErrors.ErrIdleTimeout, Data: map[string]any{"second": model.GlobalConfig.Load().Timeout}}
|
|
case <-tk1m.C:
|
|
asset, err := assetService.GetById(sess.Gctx, sess.AssetId)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if protocols.CheckTime(asset.AccessAuth) && (sess.ShareId == 0 || time.Now().Before(sess.ShareEnd)) {
|
|
continue
|
|
}
|
|
return &myErrors.ApiError{Code: myErrors.ErrAccessTime}
|
|
case closeBy := <-chs.CloseChan:
|
|
msg := (&myErrors.ApiError{Code: myErrors.ErrAdminClose, Data: map[string]any{"admin": closeBy}}).MessageWithCtx(ctx)
|
|
protocols.WriteErrMsg(sess, msg)
|
|
logger.L().Info("closed by", zap.String("admin", closeBy))
|
|
return &myErrors.ApiError{Code: myErrors.ErrAdminClose, Data: map[string]any{"admin": closeBy}}
|
|
case err = <-chs.ErrChan:
|
|
protocols.WriteErrMsg(sess, err.Error())
|
|
return
|
|
case in := <-chs.InChan:
|
|
if sess.SessionType == model.SESSIONTYPE_WEB {
|
|
rt := in[0]
|
|
msg := in[1:]
|
|
switch rt {
|
|
case '1':
|
|
in = msg
|
|
case '9':
|
|
continue
|
|
case 'w':
|
|
wh := strings.Split(string(msg), ",")
|
|
if len(wh) < 2 {
|
|
continue
|
|
}
|
|
chs.WindowChan <- ssh.Window{
|
|
Width: cast.ToInt(wh[0]),
|
|
Height: cast.ToInt(wh[1]),
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
if cmd, forbidden := sess.SshParser.AddInput(in); forbidden {
|
|
protocols.WriteErrMsg(sess, fmt.Sprintf("%s is forbidden\n", cmd))
|
|
sess.SshParser.AddInput(byteClearAll)
|
|
chs.Win.Write(byteClearAll)
|
|
continue
|
|
}
|
|
if _, err = chs.Win.Write(in); err != nil {
|
|
return
|
|
}
|
|
case out := <-chs.OutChan:
|
|
if _, err = chs.OutBuf.Write(out); err != nil {
|
|
return
|
|
}
|
|
sess.SshParser.AddOutput(out)
|
|
case <-tk.C:
|
|
if err = protocols.Write(sess); err != nil {
|
|
return
|
|
}
|
|
case <-tk1s.C:
|
|
if sess.Ws == nil {
|
|
continue
|
|
}
|
|
if err = sess.Ws.WriteMessage(websocket.TextMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
if err = sess.G.Wait(); err != nil {
|
|
logger.L().Debug("handle term wait end", zap.String("id", sess.SessionId), zap.Error(err))
|
|
}
|
|
|
|
return
|
|
}
|