fix(api): monitor, replay

This commit is contained in:
ttk
2024-09-06 19:45:01 +08:00
parent 5a63cc375a
commit a30bd94541
6 changed files with 75 additions and 96 deletions

View File

@@ -2,7 +2,6 @@ package controller
import (
"bufio"
"context"
"errors"
"fmt"
"net/http"
@@ -171,13 +170,15 @@ func HandleSsh(sess *gsession.Session) (err error) {
case <-tk.C:
write(sess)
case <-tk1s.C:
write(sess)
if sess.Ws != nil {
sess.Ws.WriteMessage(websocket.TextMessage, nil)
}
}
}
})
if err = sess.G.Wait(); err != nil {
logger.L().Debug("sess wait end", zap.String("id", sess.SessionId), zap.Error(err))
logger.L().Debug("sess wait end ssh", zap.String("id", sess.SessionId), zap.Error(err))
}
return
@@ -227,7 +228,7 @@ func handleGuacd(sess *gsession.Session) (err error) {
})
if err = sess.G.Wait(); err != nil {
logger.L().Debug("sess wait end", zap.String("id", sess.SessionId), zap.Error(err))
logger.L().Debug("sess wait end guacd", zap.String("id", sess.SessionId), zap.Error(err))
}
return
@@ -244,7 +245,7 @@ func writeToMonitors(monitors *sync.Map, out []byte) {
})
}
func DoConnect(ctx *gin.Context) (sess *gsession.Session, err error) {
func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
assetId, accountId := cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))
@@ -252,13 +253,9 @@ func DoConnect(ctx *gin.Context) (sess *gsession.Session, err error) {
if err != nil {
return
}
if !checkAuthorization(currentUser, asset, accountId) || !checkTime(asset.AccessAuth) {
err = &ApiError{Code: ErrLogin, Data: map[string]any{"err": fmt.Errorf("invalid authorization")}}
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
sess = gsession.NewSession(ctx)
sess.Ws = ws
sess.Session = &model.Session{
SessionType: ctx.GetInt("sessionType"),
SessionId: uuid.New().String(),
@@ -273,12 +270,29 @@ func DoConnect(ctx *gin.Context) (sess *gsession.Session, err error) {
Protocol: ctx.Param("protocol"),
Status: model.SESSIONSTATUS_ONLINE,
}
if sess.IsSsh() {
w, h := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h"))
if sess.SshRecoder, err = gsession.NewAsciinema(sess.SessionId, w, h); err != nil {
return
}
}
if sess.SessionType == model.SESSIONTYPE_WEB {
sess.ClientIp = ctx.ClientIP()
} else if sess.SessionType == model.SESSIONTYPE_CLIENT {
sess.ClientIp = ctx.RemoteIP()
}
if !checkTime(asset.AccessAuth) {
err = &ApiError{Code: ErrAccessTime}
ctx.AbortWithError(http.StatusBadRequest, err)
return
}
if !checkAuthorization(currentUser, asset, accountId) {
err = &ApiError{Code: ErrLogin}
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
switch strings.Split(sess.Protocol, ":")[0] {
case "ssh":
go connectSsh(ctx, sess, asset, account, gateway)
@@ -355,10 +369,6 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
return
}
if sess.SshRecoder, err = gsession.NewAsciinema(sess.SessionId, w, h); err != nil {
return
}
sess.G.Go(func() error {
err = sshSess.Wait()
return fmt.Errorf("ssh session wait end %w", err)
@@ -411,28 +421,6 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
return
}
// func newGuacdSession(ctx *gin.Context, connectionId, sessionId string, asset *model.Asset, account *model.Account, gateway *model.Gateway) *gsession.Session {
// currentUser, _ := acl.GetSessionFromCtx(ctx)
// return &gsession.Session{
// Session: &model.Session{
// SessionType: model.SESSIONTYPE_WEB,
// SessionId: sessionId,
// Uid: currentUser.GetUid(),
// UserName: currentUser.GetUserName(),
// AssetId: asset.Id,
// AssetInfo: fmt.Sprintf("%s(%s)", asset.Name, asset.Ip),
// AccountId: account.Id,
// AccountInfo: fmt.Sprintf("%s(%s)", account.Name, account.Account),
// GatewayId: gateway.Id,
// GatewayInfo: lo.Ternary(gateway.Id == 0, "", fmt.Sprintf("%s:%d", gateway.Host, gateway.Port)),
// ClientIp: ctx.ClientIP(),
// Protocol: ctx.Param("protocol"),
// Status: model.SESSIONSTATUS_ONLINE,
// },
// ConnectionId: connectionId,
// }
// }
func connectGuacd(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, account *model.Account, gateway *model.Gateway) (err error) {
chs := sess.Chans
defer func() {
@@ -443,7 +431,7 @@ func connectGuacd(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
w, h, dpi := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h")), cast.ToInt(ctx.Query("dpi"))
t, err := guacd.NewTunnel("", w, h, dpi, sess.Protocol, asset, account, gateway)
t, err := guacd.NewTunnel("", sess.SessionId, w, h, dpi, sess.Protocol, asset, account, gateway)
if err != nil {
logger.L().Error("guacd tunnel failed", zap.Error(err))
return
@@ -460,9 +448,6 @@ func connectGuacd(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
return nil
default:
p, err := t.Read()
// if isCtxDone(sess.Gctx) {
// return nil
// }
if err != nil {
return err
}
@@ -503,10 +488,6 @@ func connectGuacd(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
// @Router /connect/:asset_id/:account_id/:protocol [post]
func (c *Controller) Connect(ctx *gin.Context) {
ctx.Set("sessionType", model.SESSIONTYPE_WEB)
sess, err := DoConnect(ctx)
if err != nil {
return
}
ws, err := Upgrader.Upgrade(ctx.Writer, ctx.Request, http.Header{
"sec-websocket-protocol": {ctx.GetHeader("sec-websocket-protocol")},
@@ -516,12 +497,17 @@ func (c *Controller) Connect(ctx *gin.Context) {
return
}
defer ws.Close()
sess.Ws = ws
var sess *gsession.Session
defer func() {
handleError(ctx, sess, err, ws)
handleError(ctx, sess, err, ws, nil)
}()
sess, err = DoConnect(ctx, ws)
if err != nil {
return
}
if sess.IsSsh() {
HandleSsh(sess)
} else {
@@ -549,8 +535,9 @@ func (c *Controller) ConnectMonitor(ctx *gin.Context) {
}
defer ws.Close()
chs := gsession.NewSessionChans()
defer func() {
handleError(ctx, sess, err, ws)
handleError(ctx, sess, err, ws, chs)
}()
if !acl.IsAdmin(currentUser) {
@@ -564,25 +551,15 @@ func (c *Controller) ConnectMonitor(ctx *gin.Context) {
}
g, gctx := errgroup.WithContext(ctx)
chs := gsession.NewSessionChans()
if !sess.IsSsh() {
g.Go(func() error {
return monitGuacd(ctx, sess, ws)
return monitGuacd(ctx, sess, chs, ws)
})
}
key := fmt.Sprintf("%d-%s-%d", currentUser.Uid, sessionId, time.Now().Nanosecond())
sess.Monitors.Store(key, ws)
defer func() {
sess.Monitors.Delete(key)
if sess.IsSsh() {
if sess.SessionType == model.SESSIONTYPE_CLIENT && !sess.HasMonitors() {
close(chs.AwayChan)
}
} else {
close(chs.AwayChan)
}
}()
defer sess.Monitors.Delete(key)
g.Go(func() error {
for {
@@ -601,22 +578,24 @@ func (c *Controller) ConnectMonitor(ctx *gin.Context) {
}
})
err = g.Wait()
if err = g.Wait(); err != nil {
logger.L().Error("monitor failed", zap.Error(err))
}
}
func monitGuacd(ctx *gin.Context, sess *gsession.Session, ws *websocket.Conn) (err error) {
connectionId, chs := sess.ConnectionId, gsession.NewSessionChans()
func monitGuacd(ctx *gin.Context, sess *gsession.Session, chs *gsession.SessionChans, ws *websocket.Conn) (err error) {
w, h, dpi := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h")), cast.ToInt(ctx.Query("dpi"))
defer func() {
chs.ErrChan <- err
}()
t, err := guacd.NewTunnel(connectionId, w, h, dpi, ":", nil, nil, nil)
t, err := guacd.NewTunnel(sess.ConnectionId, "", w, h, dpi, ":", nil, nil, nil)
if err != nil {
logger.L().Error("guacd tunnel failed", zap.Error(err))
return
}
defer t.Disconnect()
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
@@ -640,13 +619,11 @@ func monitGuacd(ctx *gin.Context, sess *gsession.Session, ws *websocket.Conn) (e
g.Go(func() error {
for {
select {
case closeBy := <-chs.CloseChan:
err := fmt.Errorf("colse by admin %s", closeBy)
case <-sess.Chans.AwayChan:
err := fmt.Errorf("monitored session closed")
ws.WriteMessage(websocket.TextMessage, guacd.NewInstruction("disconnect", err.Error()).Bytes())
logger.L().Warn(err.Error())
return err
case err := <-chs.ErrChan:
logger.L().Error("disconnected", zap.Error(err))
return err
case out := <-chs.OutChan:
ws.WriteMessage(websocket.TextMessage, out)
@@ -762,9 +739,13 @@ func checkAuthorization(user *acl.Session, asset *model.Asset, accountId int) bo
return acl.IsAdmin(user) || lo.Contains(asset.Authorization[accountId], user.GetRid())
}
func handleError(ctx *gin.Context, sess *gsession.Session, err error, ws *websocket.Conn) {
func handleError(ctx *gin.Context, sess *gsession.Session, err error, ws *websocket.Conn, chs *gsession.SessionChans) {
defer func() {
close(sess.Chans.AwayChan)
if chs == nil {
close(sess.Chans.AwayChan)
} else {
close(chs.AwayChan)
}
}()
if err == nil {
@@ -772,23 +753,10 @@ func handleError(ctx *gin.Context, sess *gsession.Session, err error, ws *websoc
}
logger.L().Debug("", zap.String("session_id", sess.SessionId), zap.Error(err))
ae, ok := err.(*ApiError)
if !ok {
return
}
if sess.IsSsh() {
ws.WriteMessage(websocket.TextMessage, []byte(ae.MessageWithCtx(ctx)))
writeErrMsg(sess, lo.Ternary(ok, ae.MessageWithCtx(ctx), err.Error()))
} else {
ws.WriteMessage(websocket.TextMessage, guacd.NewInstruction("error", (ae).MessageBase64(ctx), cast.ToString(ErrAdminClose)).Bytes())
}
// ctx.AbortWithError(http.StatusBadRequest, err)
}
func isCtxDone(ctx context.Context) bool {
select {
case _, ok := <-ctx.Done():
return !ok
default:
return false
ws.WriteMessage(websocket.TextMessage, guacd.NewInstruction("error", lo.Ternary(ok, (ae).MessageBase64(ctx), err.Error()), cast.ToString(ErrAdminClose)).Bytes())
}
}

View File

@@ -83,6 +83,9 @@ func (ae *ApiError) Message(localizer *i18n.Localizer) (msg string) {
}
func (ae *ApiError) MessageWithCtx(ctx *gin.Context) string {
if ae == nil {
return ""
}
lang := ctx.PostForm("lang")
accept := ctx.GetHeader("Accept-Language")
localizer := i18n.NewLocalizer(myi18n.Bundle, lang, accept)

View File

@@ -277,25 +277,29 @@ func toListData[T any](data []T) *ListData {
}
}
func nodeCountAsset() (res map[int]int64, err error) {
func nodeCountAsset() (m map[int]int64, err error) {
assets := make([]*model.AssetIdPid, 0)
if err = mysql.DB.Model(&model.Asset{}).Find(&assets).Error; err != nil {
return
}
res = make(map[int]int64)
nodes := make([]*model.NodeIdPid, 0)
if err = mysql.DB.Model(&model.Node{}).Find(&nodes).Error; err != nil {
return
}
m = make(map[int]int64)
for _, a := range assets {
res[a.ParentId] += 1
m[a.ParentId] += 1
}
g := make(map[int][]int)
for _, n := range assets {
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int) int64
dfs = func(x int) int64 {
for _, y := range g[x] {
res[x] += dfs(y)
m[x] += dfs(y)
}
return res[x]
return m[x]
}
dfs(0)

View File

@@ -7,7 +7,6 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/veops/oneterm/conf"
@@ -44,7 +43,7 @@ type Tunnel struct {
gw *ggateway.GatewayTunnel
}
func NewTunnel(connectionId string, w, h, dpi int, protocol string, asset *model.Asset, account *model.Account, gateway *model.Gateway) (t *Tunnel, err error) {
func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, asset *model.Asset, account *model.Account, gateway *model.Gateway) (t *Tunnel, err error) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", conf.Cfg.Guacd.Host, conf.Cfg.Guacd.Port), time.Second*3)
if err != nil {
return
@@ -88,7 +87,7 @@ func NewTunnel(connectionId string, w, h, dpi int, protocol string, asset *model
},
}
if t.ConnectionId == "" {
t.SessionId = uuid.New().String()
t.SessionId = sessionId
t.Config.Parameters["recording-name"] = t.SessionId
}
if gateway != nil && gateway.Id != 0 && t.ConnectionId == "" {

View File

@@ -672,7 +672,8 @@ func (m Model) View() string {
v += styleText(m.echoTransform(string(value[pos+1:]))) // text after cursor
v += m.completionView(0) // suggested completion
} else {
if m.canAcceptSuggestion() {
if m.canAcceptSuggestion() && len(m.matchedSuggestions) <= 1 {
suggestion := m.matchedSuggestions[m.currentSuggestionIndex]
if len(value) < len(suggestion) {
m.Cursor.TextStyle = m.CompletionStyle

View File

@@ -175,7 +175,11 @@ func (m *view) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
case errMsg:
if msg != nil {
return m, tea.Printf(" [ERROR] %s\n\n", errStyle.Render(msg.Error()))
str := msg.Error()
if ae, ok := msg.(*controller.ApiError); ok {
str = controller.Err2Msg[ae.Code].One
}
return m, tea.Printf(" [ERROR] %s\n\n", errStyle.Render(str))
}
}
m.textinput, tiCmd = m.textinput.Update(msg)
@@ -327,7 +331,7 @@ func (conn *connector) SetStderr(w io.Writer) {
}
func (conn *connector) Run() error {
gsess, err := controller.DoConnect(conn.Ctx)
gsess, err := controller.DoConnect(conn.Ctx, nil)
if err != nil {
return err
}