Files
oneterm/backend/internal/api/controller/connect/other.go

170 lines
4.1 KiB
Go

package connect
import (
"bufio"
"bytes"
"database/sql"
"fmt"
"strings"
"time"
"unicode/utf8"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/lipgloss/table"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/samber/lo"
"github.com/spf13/cast"
mysqlDriver "gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/veops/oneterm/internal/model"
gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/internal/tunneling"
)
// connectOther runs an interactive, line-based console for a non-shell
// protocol on top of the session's channels. The protocol name is the part of
// sess.Protocol before the first ":"; "redis" and "mysql" are handled here.
//
// Flow:
//  1. Open a proxy endpoint with tunneling.Proxy and connect to it (Redis is
//     verified with PING, MySQL is opened through gorm).
//  2. Send the connection result on chs.ErrChan: nil on success, or the error
//     (through the deferred handler) on failure.
//  3. Start two goroutines on sess.G: one reads runes from chs.Rin, echoes
//     them to chs.OutChan, handles deletions and executes every line that is
//     terminated by byteR; the other drains chs.WindowChan until the session
//     context is done or chs.AwayChan is signalled.
//  4. Block until both goroutines finish, then release the backend client.
//
// Typing "exit" closes chs.AwayChan (once) and stops the reader.
//
// The returned error is non-nil only when the proxy or the initial connection
// fails; errors returned by the session goroutines are not propagated.
func connectOther(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, account *model.Account, gateway *model.Gateway) (err error) {
	chs := sess.Chans
	// Report a failed setup to whoever is waiting on ErrChan.
	defer func() {
		if err != nil {
			chs.ErrChan <- err
		}
	}()

	protocol := strings.Split(sess.Protocol, ":")[0]
	ip, port, err := tunneling.Proxy(false, sess.SessionId, protocol, asset, gateway)
	if err != nil {
		return err
	}

	var (
		rdb *redis.Client
		db  *gorm.DB
	)
	switch protocol {
	case "redis":
		rdb = redis.NewClient(&redis.Options{
			Addr:        fmt.Sprintf("%s:%d", ip, port),
			Password:    account.Password,
			DialTimeout: time.Second,
		})
		// Release the client's connections on every exit path, including a
		// failed ping below.
		defer func() { _ = rdb.Close() }()
		if _, err = rdb.Ping(ctx).Result(); err != nil {
			return err
		}
	case "mysql":
		dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&parseTime=True&loc=Local", account.Account, account.Password, ip, port)
		db, err = gorm.Open(mysqlDriver.Open(dsn))
		if err != nil {
			return err
		}
		// gorm does not close the underlying pool on its own; close it when
		// the session ends.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			defer func() { _ = sqlDB.Close() }()
		}
	}
	// err is nil here: signal that the connection has been established.
	chs.ErrChan <- err

	// queryMySQL runs one statement and renders its result set as a bordered
	// table with "\r\n" line endings. The result set is always closed, and
	// iteration errors are reported instead of being silently dropped.
	queryMySQL := func(query string) (any, error) {
		var rows *sql.Rows
		rows, qErr := db.WithContext(ctx).Raw(query).Rows()
		if qErr != nil {
			return nil, qErr
		}
		defer rows.Close()

		heads, qErr := rows.Columns()
		if qErr != nil {
			return nil, qErr
		}
		rs := make([][]string, 0)
		for rows.Next() {
			// Scan needs one pointer destination per column.
			cells := make([]any, len(heads))
			for i := range cells {
				cells[i] = new(any)
			}
			if qErr = rows.Scan(cells...); qErr != nil {
				return nil, qErr
			}
			rs = append(rs, lo.Map(cells, func(v any, _ int) string { return cast.ToString(v) }))
		}
		if qErr = rows.Err(); qErr != nil {
			return nil, qErr
		}
		return strings.ReplaceAll(table.New().Border(border).Headers(heads...).Rows(rs...).String(), "\n", "\r\n"), nil
	}

	// Reader: echo input, maintain the current line and execute it on byteR.
	sess.G.Go(func() error {
		reader := bufio.NewReader(chs.Rin)
		buf := &bytes.Buffer{}
		pt := ""
		if ss := strings.Split(sess.Protocol, ":"); len(ss) == 2 {
			pt = ss[1]
		}
		sess.Prompt = fmt.Sprintf("%s@%s:%s> ", account.Account, asset.Name, pt)
		// Build the banner in a fresh slice: appending to the package-level
		// byteRN directly could write into its shared spare capacity.
		banner := make([]byte, 0, len(byteRN)+len(sess.Prompt))
		banner = append(banner, byteRN...)
		banner = append(banner, sess.Prompt...)
		chs.OutChan <- banner

		for {
			// Non-blocking cancellation check before the next (blocking) read.
			select {
			case <-sess.Gctx.Done():
				return nil
			default:
			}

			rn, size, readErr := reader.ReadRune()
			if readErr != nil {
				return readErr
			}
			if size <= 0 || rn == utf8.RuneError {
				continue
			}
			p := make([]byte, utf8.RuneLen(rn))
			utf8.EncodeRune(p, rn)
			p = bytes.ReplaceAll(p, byteT, byteS)

			// Deletion: strip the delete byte and erase the last buffered
			// rune, emitting one byteClearCur per display column it occupied.
			for bytes.HasSuffix(p, byteDel) {
				p = p[:len(p)-1]
				if buf.Len() > 0 {
					var dels []byte
					last, ok := lo.Last([]rune(buf.String()))
					for i := 0; ok && i < lipgloss.Width(string(last)); i++ {
						dels = append(dels, byteClearCur...)
					}
					chs.OutChan <- dels
					buf.Truncate(buf.Len() - len([]byte(string(last))))
				}
			}
			if len(p) == 0 {
				continue
			}

			// Echo the input and add it to the current line.
			chs.OutChan <- p
			buf.Write(p)
			bs := buf.Bytes()
			if bytes.Contains(bs, byteClearAll) {
				// The clear-all sequence discards the line typed so far.
				buf.Reset()
				continue
			}
			if !bytes.Contains(bs, byteR) {
				// Line not terminated yet.
				continue
			}

			// Copy the line (without its trailing terminator byte) out of the
			// buffer before resetting it, so it does not alias buffer storage.
			line := string(bs[:len(bs)-1])
			if line == "exit" {
				sess.Once.Do(func() { close(chs.AwayChan) })
				return nil
			}
			buf.Reset()

			var (
				res    any
				cmdErr error
			)
			if line != "" {
				switch protocol {
				case "redis":
					parts := lo.Map(reRedis.FindAllString(line, -1), func(s string, _ int) any { return s })
					res, cmdErr = rdb.Do(ctx, parts...).Result()
				case "mysql":
					res, cmdErr = queryMySQL(line)
				}
			}

			// Show the error if there is one, otherwise the result (or
			// nothing). %v is used because Redis replies are not always
			// strings (e.g. int64); %s would render those as "%!s(int64=1)".
			var shown any = ""
			switch {
			case cmdErr != nil:
				shown = cmdErr
			case res != nil:
				shown = res
			}
			chs.OutChan <- []byte(fmt.Sprintf("\n%v\r\n%s", shown, sess.Prompt))
		}
	})

	// Watcher: discard window events and stop when the session context is
	// done or AwayChan is signalled.
	sess.G.Go(func() error {
		for {
			select {
			case <-sess.Gctx.Done():
				return nil
			case <-chs.AwayChan:
				return nil
			case <-chs.WindowChan:
				continue
			}
		}
	})

	// The goroutines' error is deliberately not returned, matching the
	// previous behavior: returning it would also push it onto ErrChan through
	// the deferred handler above.
	_ = sess.G.Wait()
	return nil
}