feat(api): redis

This commit is contained in:
ttk
2024-10-16 19:30:48 +08:00
parent 0aadda2f10
commit eebad0b616
6 changed files with 241 additions and 33 deletions

View File

@@ -2,6 +2,7 @@ package controller
import (
"bufio"
"bytes"
"errors"
"fmt"
"net/http"
@@ -11,16 +12,19 @@ import (
"time"
"unicode/utf8"
"github.com/charmbracelet/lipgloss/table"
"github.com/gin-gonic/gin"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/redis/go-redis/v9"
"github.com/samber/lo"
"github.com/spf13/cast"
"go.uber.org/zap"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
mysqlDriver "gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/veops/oneterm/acl"
@@ -43,7 +47,14 @@ var (
return true
},
}
clear = []byte("\x15\r")
byteClearAll = []byte("\x15\r")
byteClearCur = []byte("\b\x1b[J")
byteDel = []byte{'\x7f'}
byteR = []byte{'\r'}
byteN = []byte{'\n'}
byteRN = []byte{'\r', '\n'}
reRedis = regexp.MustCompile(`("[^"]*"|'[^']*'|\S+)`)
)
func read(sess *gsession.Session) error {
@@ -64,7 +75,7 @@ func read(sess *gsession.Session) error {
switch t {
case websocket.TextMessage:
chs.InChan <- msg
if (sess.IsSsh() && len(msg) > 0 && msg[0] != '9') || (!sess.IsSsh() && guacd.IsActive(msg)) {
if (sess.IsGuacd() && len(msg) > 0 && msg[0] != '9') || (!sess.IsGuacd() && guacd.IsActive(msg)) {
sess.SetIdle()
}
}
@@ -81,14 +92,14 @@ func write(sess *gsession.Session) (err error) {
out := chs.OutBuf.Bytes()
if sess.SessionType == model.SESSIONTYPE_WEB && sess.Ws != nil {
if len(out) > 0 || !strings.Contains(sess.Protocol, "ssh") {
if len(out) > 0 || !sess.IsGuacd() {
err = sess.Ws.WriteMessage(websocket.TextMessage, out)
}
} else if sess.SessionType == model.SESSIONTYPE_CLIENT && len(out) > 0 {
_, err = sess.CliRw.Write(out)
}
if sess.SshRecoder != nil && len(out) > 0 && strings.Contains(sess.Protocol, "ssh") {
if sess.SshRecoder != nil && len(out) > 0 && !sess.IsGuacd() {
sess.SshRecoder.Write(out)
}
@@ -105,7 +116,7 @@ func writeErrMsg(sess *gsession.Session, msg string) {
write(sess)
}
func HandleSsh(sess *gsession.Session) (err error) {
func HandleTerm(sess *gsession.Session) (err error) {
defer func() {
logger.L().Debug("defer HandleSsh", zap.String("sessionId", sess.SessionId))
sess.SshParser.WriteDb()
@@ -170,8 +181,8 @@ func HandleSsh(sess *gsession.Session) (err error) {
}
if cmd, forbidden := sess.SshParser.AddInput(in); forbidden {
writeErrMsg(sess, fmt.Sprintf("%s is forbidden\n", cmd))
sess.SshParser.AddInput(clear)
chs.Win.Write(clear)
sess.SshParser.AddInput(byteClearAll)
chs.Win.Write(byteClearAll)
continue
}
if _, err = chs.Win.Write(in); err != nil {
@@ -296,7 +307,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
return
}
}
if sess.IsSsh() {
if !sess.IsGuacd() {
w, h := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h"))
sess.SshParser = gsession.NewParser(sess.SessionId, w, h)
if err = mysql.DB.Model(sess.SshParser.Cmds).Where("id IN ? AND enable=?", []int(asset.AccessAuth.CmdIds), true).
@@ -330,6 +341,10 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
switch strings.Split(sess.Protocol, ":")[0] {
case "ssh":
go connectSsh(ctx, sess, asset, account, gateway)
case "redis":
go connectRedis(ctx, sess, asset, account, gateway)
case "mysql":
go connectMysql(ctx, sess, asset, account, gateway)
case "vnc", "rdp":
go connectGuacd(ctx, sess, asset, account, gateway)
default:
@@ -492,7 +507,6 @@ func connectGuacd(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
if len(p) <= 0 {
continue
}
chs.OutChan <- p
}
}
@@ -515,6 +529,171 @@ func connectGuacd(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
return
}
func connectRedis(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, account *model.Account, gateway *model.Gateway) (err error) {
chs := sess.Chans
defer func() {
ggateway.GetGatewayManager().Close(sess.SessionId)
if err != nil {
chs.ErrChan <- err
}
}()
ip, port, err := util.Proxy(false, sess.SessionId, "redis", asset, gateway)
if err != nil {
return
}
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", ip, port),
Password: account.Password,
DialTimeout: time.Second,
})
_, err = rdb.Ping(ctx).Result()
if err != nil {
return
}
chs.ErrChan <- err
sess.G.Go(func() error {
reader := bufio.NewReader(chs.Rin)
buf := &bytes.Buffer{}
pt := ""
ss := strings.Split(sess.Protocol, ":")
if len(ss) == 2 {
pt = ss[1]
}
prompt := fmt.Sprintf("%s@%s:%s> ", account.Account, asset.Name, pt)
chs.OutChan <- append(byteRN, []byte(prompt)...)
for {
select {
case <-sess.Gctx.Done():
return nil
default:
rn, size, err := reader.ReadRune()
if err != nil {
return err
}
if size <= 0 || rn == utf8.RuneError {
continue
}
p := make([]byte, utf8.RuneLen(rn))
utf8.EncodeRune(p, rn)
for bytes.HasSuffix(p, byteDel) {
p = p[:len(p)-1]
chs.OutChan <- byteClearCur
if buf.Len() > 0 {
buf.Truncate(buf.Len() - 1)
}
}
if len(p) <= 0 {
continue
}
chs.OutChan <- p
buf.Write(p)
bs := buf.Bytes()
if idx := bytes.LastIndex(bs, byteClearAll); idx >= 0 {
buf.Reset()
continue
}
if idx := bytes.LastIndex(bs, byteR); idx < 0 {
continue
}
bs = bs[:len(bs)-1]
if bytes.Equal(bs, []byte("exit")) {
sess.SshParser.AddOutput(byteRN)
sess.Once.Do(func() { close(chs.AwayChan) })
return nil
}
buf.Reset()
var res any
if len(bs) > 0 {
parts := lo.Map(reRedis.FindAllString(string(bs), -1), func(p string, _ int) any { return p })
res, err = rdb.Do(ctx, parts...).Result()
}
chs.OutChan <- []byte(fmt.Sprintf("\n%s\r\n%s", lo.Ternary[any](err == nil, lo.Ternary(res == nil, "", res), err), prompt))
}
}
})
sess.G.Go(func() error {
defer rdb.Close()
for {
select {
case <-sess.Gctx.Done():
return nil
case <-chs.AwayChan:
return fmt.Errorf("away")
}
}
})
sess.G.Wait()
return
}
func connectMysql(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, account *model.Account, gateway *model.Gateway) (err error) {
chs := sess.Chans
defer func() {
ggateway.GetGatewayManager().Close(sess.SessionId)
if err != nil {
chs.ErrChan <- err
}
}()
ip, port, err := util.Proxy(false, sess.SessionId, "mysql", asset, gateway)
if err != nil {
return
}
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)?charset=utf8mb4&parseTime=True&loc=Local", account.Account, account.Password, ip, port)
db, err := gorm.Open(mysqlDriver.Open(dsn))
if err != nil {
return
}
chs.ErrChan <- err
sess.G.Go(func() error {
buf := &bytes.Buffer{}
prompt := fmt.Sprintf("%s@%s:%d> ", account.Account, asset.Name, port)
sess.Chans.OutChan <- []byte(prompt)
for {
select {
case <-sess.Gctx.Done():
return nil
case <-chs.AwayChan:
return fmt.Errorf("away")
case in := <-chs.InChan:
sess.Chans.OutChan <- in
buf.Write(in)
if idx := bytes.LastIndex(in, byteN); idx < 0 {
continue
}
rows, err := db.WithContext(ctx).Raw(buf.String()).Rows()
if err != nil {
sess.Chans.OutChan <- []byte(fmt.Sprintf("%s\n\n%s", err, prompt))
} else {
heads, _ := rows.Columns()
rs := make([][]string, 0)
for rows.Next() {
r := make([]any, 0)
rows.Scan(&r)
rs = append(rs, lo.Map(r, func(v any, _ int) string { return cast.ToString(v) }))
}
t := table.New().Headers(heads...).Rows(rs...)
sess.Chans.OutChan <- []byte(fmt.Sprintf("%s\n\n%s", t.String(), prompt))
}
}
}
})
sess.G.Wait()
return
}
// Connect godoc
//
// @Tags connect
@@ -546,10 +725,10 @@ func (c *Controller) Connect(ctx *gin.Context) {
return
}
if sess.IsSsh() {
HandleSsh(sess)
} else {
if sess.IsGuacd() {
handleGuacd(sess)
} else {
HandleTerm(sess)
}
}
@@ -589,13 +768,13 @@ func (c *Controller) ConnectMonitor(ctx *gin.Context) {
}
g, gctx := errgroup.WithContext(ctx)
if !sess.IsSsh() {
if sess.IsGuacd() {
g.Go(func() error {
return monitGuacd(ctx, sess, chs, ws)
})
}
key := fmt.Sprintf("%d-%s-%d", currentUser.Uid, sessionId, time.Now().Nanosecond())
key := fmt.Sprintf("%d-%s-%d", currentUser.GetUid(), sessionId, time.Now().Nanosecond())
sess.Monitors.Store(key, ws)
defer sess.Monitors.Delete(key)
@@ -609,7 +788,7 @@ func (c *Controller) ConnectMonitor(ctx *gin.Context) {
if err != nil {
return err
}
if !sess.IsSsh() {
if sess.IsGuacd() {
chs.InChan <- p
}
}
@@ -775,11 +954,11 @@ func checkTime(data model.AccessAuth) bool {
func handleError(ctx *gin.Context, sess *gsession.Session, err error, ws *websocket.Conn, chs *gsession.SessionChans) {
defer func() {
if chs == nil {
close(sess.Chans.AwayChan)
} else {
close(chs.AwayChan)
ch := sess.Chans.AwayChan
if chs != nil {
ch = chs.AwayChan
}
sess.Once.Do(func() { close(ch) })
}()
if err == nil {
@@ -787,9 +966,9 @@ func handleError(ctx *gin.Context, sess *gsession.Session, err error, ws *websoc
}
logger.L().Debug("", zap.String("session_id", sess.SessionId), zap.Error(err))
ae, ok := err.(*ApiError)
if sess.IsSsh() {
writeErrMsg(sess, lo.Ternary(ok, ae.MessageWithCtx(ctx), err.Error()))
} else {
if sess.IsGuacd() {
ws.WriteMessage(websocket.TextMessage, guacd.NewInstruction("error", lo.Ternary(ok, (ae).MessageBase64(ctx), err.Error()), cast.ToString(ErrAdminClose)).Bytes())
} else {
writeErrMsg(sess, lo.Ternary(ok, ae.MessageWithCtx(ctx), err.Error()))
}
}