refactor(backend): optimize code organization structure

This commit is contained in:
pycook
2025-03-02 21:36:24 +08:00
parent fd6986ab73
commit 57291e6737
81 changed files with 979 additions and 677 deletions

View File

@@ -1,142 +0,0 @@
package api
import (
"bytes"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/nicksnyder/go-i18n/v2/i18n"
"go.uber.org/zap"
"github.com/veops/oneterm/acl"
"github.com/veops/oneterm/api/controller"
myi18n "github.com/veops/oneterm/i18n"
"github.com/veops/oneterm/logger"
)
var (
errUnauthorized = &controller.ApiError{Code: controller.ErrUnauthorized}
)
func ginLogger() gin.HandlerFunc {
return func(ctx *gin.Context) {
start := time.Now()
ctx.Next()
cost := time.Since(start)
logger.L().Info(ctx.Request.URL.String(),
zap.String("method", ctx.Request.Method),
zap.Int("status", ctx.Writer.Status()),
zap.String("ip", ctx.ClientIP()),
zap.Duration("cost", cost),
)
}
}
func auth() gin.HandlerFunc {
return func(ctx *gin.Context) {
var (
sess *acl.Session
err error
cookie string
)
m := make(map[string]any)
ctx.ShouldBindBodyWithJSON(&m)
if ctx.Request.Method == "GET" {
if _, ok := ctx.GetQuery("_key"); ok {
m["_key"] = ctx.Query("_key")
m["_secret"] = ctx.Query("_secret")
}
}
if _, ok := m["_key"]; ok {
sess, err = acl.AuthWithKey(ctx.Request.URL.Path, m)
if err != nil {
logger.L().Error("cannot authwithkey", zap.Error(err))
ctx.AbortWithError(http.StatusUnauthorized, errUnauthorized)
return
}
ctx.Set("isAuthWithKey", true)
} else {
cookie, err = ctx.Cookie("session")
if err != nil || cookie == "" {
logger.L().Error("cannot get cookie.session", zap.Error(err))
ctx.AbortWithError(http.StatusUnauthorized, errUnauthorized)
return
}
sess, err = acl.ParseCookie(cookie)
}
if err != nil {
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
ctx.Set("session", sess)
ctx.Next()
}
}
func authAdmin() gin.HandlerFunc {
return func(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
if !acl.IsAdmin(currentUser) {
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
}
}
type bodyWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
func (w bodyWriter) Write(b []byte) (int, error) {
return w.body.Write(b)
}
func Error2Resp() gin.HandlerFunc {
return func(ctx *gin.Context) {
if strings.Contains(ctx.Request.URL.String(), "session/replay") {
ctx.Next()
return
}
wb := &bodyWriter{
body: &bytes.Buffer{},
ResponseWriter: ctx.Writer,
}
ctx.Writer = wb
ctx.Next()
obj := make(map[string]any)
json.Unmarshal(wb.body.Bytes(), &obj)
if len(ctx.Errors) > 0 {
if v, ok := obj["code"]; !ok || v == 0 {
obj["code"] = ctx.Writer.Status()
}
if v, ok := obj["message"]; !ok || v == "" {
e := ctx.Errors.Last().Err
obj["message"] = e.Error()
ae, ok := e.(*controller.ApiError)
if ok {
lang := ctx.PostForm("lang")
accept := ctx.GetHeader("Accept-Language")
localizer := i18n.NewLocalizer(myi18n.Bundle, lang, accept)
obj["message"] = ae.Message(localizer)
}
}
}
bs, _ := json.Marshal(obj)
wb.ResponseWriter.Write(bs)
}
}

View File

@@ -7,11 +7,12 @@ import (
"syscall" "syscall"
"github.com/oklog/run" "github.com/oklog/run"
"github.com/veops/oneterm/api"
"github.com/veops/oneterm/logger"
"github.com/veops/oneterm/schedule"
"github.com/veops/oneterm/sshsrv"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/internal/api"
"github.com/veops/oneterm/internal/schedule"
"github.com/veops/oneterm/internal/sshsrv"
"github.com/veops/oneterm/pkg/logger"
) )
func main() { func main() {

View File

@@ -19,6 +19,20 @@ mysql:
user: root user: root
password: root password: root
database:
type: mysql # alternative: postgres, tidb, tdsql, dm
host: oneterm-mysql
port: 3306
user: root
password: root
database: oneterm
charset: utf8mb4
max_idle_conns: 10
max_open_conns: 100
conn_max_lifetime: 3600 # seconds
conn_max_idle_time: 600 # seconds
ssl_mode: disable
redis: redis:
addr: oneterm-redis:6379 addr: oneterm-redis:6379
password: root password: root

View File

@@ -1,52 +0,0 @@
package mysql
import (
"fmt"
"strings"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/veops/oneterm/conf"
"github.com/veops/oneterm/logger"
"github.com/veops/oneterm/model"
)
var (
DB *gorm.DB
)
func init() {
var err error
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/oneterm?charset=utf8mb4&parseTime=True&loc=Local",
conf.Cfg.Mysql.User, conf.Cfg.Mysql.Password, conf.Cfg.Mysql.Host, conf.Cfg.Mysql.Port)
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
if err != nil {
logger.L().Fatal("init mysql failed", zap.Error(err))
}
err = DB.AutoMigrate(
model.DefaultAccount, model.DefaultAsset, model.DefaultAuthorization, model.DefaultCommand,
model.DefaultConfig, model.DefaultFileHistory, model.DefaultGateway, model.DefaultHistory,
model.DefaultNode, model.DefaultPublicKey, model.DefaultSession, model.DefaultSessionCmd,
model.DefaultShare,
)
if err != nil {
logger.L().Fatal("auto migrate mysql failed", zap.Error(err))
}
dropIndexs := map[string]any{
"asset_account_id_del": &model.Authorization{},
}
for k, v := range dropIndexs {
if !DB.Migrator().HasIndex(v, k) {
continue
}
if err = DB.Migrator().DropIndex(v, k); err != nil && !strings.Contains(err.Error(), "1091") {
logger.L().Fatal("drop index failed", zap.Error(err))
}
}
}

View File

@@ -2,7 +2,7 @@ FROM golang:alpine
WORKDIR /oneterm WORKDIR /oneterm
COPY . . COPY . .
RUN go env -w GOPROXY=https://goproxy.cn,direct \ RUN go env -w GOPROXY=https://goproxy.cn,direct \
&& go build --ldflags "-s -w" -o ./build/oneterm ./main.go && go build --ldflags "-s -w" -o ./build/oneterm ./cmd/server/main.go
FROM alpine:latest FROM alpine:latest
RUN set -eux && sed -i 's/dl-cdn.alpinelinux.org/mirrors.ustc.edu.cn/g' /etc/apk/repositories RUN set -eux && sed -i 's/dl-cdn.alpinelinux.org/mirrors.ustc.edu.cn/g' /etc/apk/repositories
@@ -10,8 +10,8 @@ RUN apk add tzdata
ENV TZ=Asia/Shanghai ENV TZ=Asia/Shanghai
ENV TERM=xterm-256color ENV TERM=xterm-256color
WORKDIR /oneterm WORKDIR /oneterm
COPY --from=0 /oneterm/config.example.yaml ./config.yaml COPY --from=0 /oneterm/configs/config.example.yaml ./config.yaml
COPY --from=0 /oneterm/i18n/translate ./translate COPY --from=0 /oneterm/internal/i18n/locales ./locales
COPY --from=0 /oneterm/build/oneterm . COPY --from=0 /oneterm/build/oneterm .
CMD [ "./oneterm","run","-c","./config.yaml"] CMD [ "./oneterm","run","-c","./config.yaml"]

View File

@@ -43,6 +43,11 @@ require (
github.com/charmbracelet/x/term v0.1.1 // indirect github.com/charmbracelet/x/term v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.1.0 // indirect github.com/charmbracelet/x/windows v0.1.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/golang/snappy v0.0.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-localereader v0.0.1 // indirect
@@ -52,6 +57,7 @@ require (
github.com/muesli/termenv v0.15.2 // indirect github.com/muesli/termenv v0.15.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
gorm.io/driver/postgres v1.5.11 // indirect
) )
require ( require (

View File

@@ -91,6 +91,8 @@ github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -100,6 +102,14 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -301,6 +311,7 @@ golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -342,6 +353,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314=
gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI=
gorm.io/driver/sqlite v1.1.3 h1:BYfdVuZB5He/u9dt4qDpZqiqDJ6KhPqs5QUqsr/Eeuc= gorm.io/driver/sqlite v1.1.3 h1:BYfdVuZB5He/u9dt4qDpZqiqDJ6KhPqs5QUqsr/Eeuc=
gorm.io/driver/sqlite v1.1.3/go.mod h1:AKDgRWk8lcSQSw+9kxCJnX/yySj8G3rdwYlU57cB45c= gorm.io/driver/sqlite v1.1.3/go.mod h1:AKDgRWk8lcSQSw+9kxCJnX/yySj8G3rdwYlU57cB45c=
gorm.io/gorm v1.20.1/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.20.1/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=

View File

@@ -14,8 +14,8 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/remote" "github.com/veops/oneterm/pkg/remote"
) )
// SigningAlgorithm provides interfaces to generate and verify signature // SigningAlgorithm provides interfaces to generate and verify signature
@@ -158,7 +158,7 @@ func AuthWithKey(path string, originData map[string]any) (sess *Session, err err
payload[k] = v payload[k] = v
} }
body["payload"] = payload body["payload"] = payload
url := fmt.Sprintf("%s%s", conf.Cfg.Auth.Acl.Url, "/acl/auth_with_key") url := fmt.Sprintf("%s%s", config.Cfg.Auth.Acl.Url, "/acl/auth_with_key")
data := &AuthWithKeyResp{} data := &AuthWithKeyResp{}
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetBody(body). SetBody(body).

View File

@@ -10,17 +10,18 @@ import (
"strings" "strings"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/veops/oneterm/conf"
mysql "github.com/veops/oneterm/db"
"github.com/veops/oneterm/logger"
"github.com/veops/oneterm/model"
"github.com/veops/oneterm/remote"
"github.com/veops/oneterm/util"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
"github.com/veops/oneterm/pkg/remote"
"github.com/veops/oneterm/pkg/utils"
) )
func LoginByPassword(ctx context.Context, username string, password string, ip string) (sess *Session, err error) { func LoginByPassword(ctx context.Context, username string, password string, ip string) (sess *Session, err error) {
url := fmt.Sprintf("%s/acl/login", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/login", config.Cfg.Auth.Acl.Url)
data := make(map[string]any) data := make(map[string]any)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
@@ -56,9 +57,9 @@ func LoginByPassword(ctx context.Context, username string, password string, ip s
func LoginByPublicKey(ctx context.Context, username string, pk string, ip string) (sess *Session, err error) { func LoginByPublicKey(ctx context.Context, username string, pk string, ip string) (sess *Session, err error) {
pk = strings.TrimSpace(pk) pk = strings.TrimSpace(pk)
enc := util.EncryptAES(pk) enc := utils.EncryptAES(pk)
cnt := int64(0) cnt := int64(0)
if err = mysql.DB.Model(&model.PublicKey{}).Where("username = ? AND pk = ?", username, enc).Count(&cnt).Error; err != nil || cnt == 0 { if err = dbpkg.DB.Model(&model.PublicKey{}).Where("username = ? AND pk = ?", username, enc).Count(&cnt).Error; err != nil || cnt == 0 {
err = fmt.Errorf("%w", err) err = fmt.Errorf("%w", err)
logger.L().Warn("find pk failed", zap.Int64("cnt", cnt), zap.Error(err)) logger.L().Warn("find pk failed", zap.Int64("cnt", cnt), zap.Error(err))
return return
@@ -69,7 +70,7 @@ func LoginByPublicKey(ctx context.Context, username string, pk string, ip string
return return
} }
url := fmt.Sprintf("%s/acl/users/info", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/users/info", config.Cfg.Auth.Acl.Url)
data := &UserInfoResp{} data := &UserInfoResp{}
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
@@ -119,7 +120,7 @@ func LoginByPublicKey(ctx context.Context, username string, pk string, ip string
} }
func ParseCookie(cookie string) (sess *Session, err error) { func ParseCookie(cookie string) (sess *Session, err error) {
s := NewSignature(conf.Cfg.SecretKey, "cookie-session", "", "hmac", nil, nil) s := NewSignature(config.Cfg.SecretKey, "cookie-session", "", "hmac", nil, nil)
content, err := s.Unsign(cookie) content, err := s.Unsign(cookie)
if err != nil { if err != nil {
logger.L().Error("cannot unsign", zap.Error(err)) logger.L().Error("cannot unsign", zap.Error(err))
@@ -139,7 +140,7 @@ func Logout(sess *Session) {
if sess == nil { if sess == nil {
return return
} }
url := fmt.Sprintf("%s/acl/logout", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/logout", config.Cfg.Auth.Acl.Url)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetCookie(sess.Cookie). SetCookie(sess.Cookie).
Post(url) Post(url)

View File

@@ -10,11 +10,11 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/logger" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/logger"
"github.com/veops/oneterm/remote" "github.com/veops/oneterm/pkg/remote"
) )
func init() { func init() {
@@ -45,7 +45,7 @@ func migrateNode() {
} }
nodes := make([]*model.Node, 0) nodes := make([]*model.Node, 0)
if err = mysql.DB.Model(&nodes).Where("resource_id = 0").Or("resource_id IS NULL").Find(&nodes).Error; err != nil { if err = dbpkg.DB.Model(&nodes).Where("resource_id = 0").Or("resource_id IS NULL").Find(&nodes).Error; err != nil {
logger.L().Fatal("get nodes failed", zap.Error(err)) logger.L().Fatal("get nodes failed", zap.Error(err))
} }
eg := errgroup.Group{} eg := errgroup.Group{}
@@ -56,7 +56,7 @@ func migrateNode() {
if err != nil { if err != nil {
return err return err
} }
if err := mysql.DB.Model(&nd).Where("id=?", nd.Id).Update("resource_id", r.ResourceId).Error; err != nil { if err := dbpkg.DB.Model(&nd).Where("id=?", nd.Id).Update("resource_id", r.ResourceId).Error; err != nil {
return err return err
} }
return nil return nil
@@ -74,7 +74,7 @@ func GetResourceTypes(ctx context.Context) (rt []*ResourceType, err error) {
} }
data := &ResourceTypeResp{} data := &ResourceTypeResp{}
url := fmt.Sprintf("%s/acl/resource_types", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/resource_types", config.Cfg.Auth.Acl.Url)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,
@@ -98,7 +98,7 @@ func AddResourceTypes(ctx context.Context, rt *ResourceType) (err error) {
return return
} }
url := fmt.Sprintf("%s/acl/resource_types", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/resource_types", config.Cfg.Auth.Acl.Url)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,
@@ -117,7 +117,7 @@ func AddResource(ctx context.Context, uid int, resourceTypeId string, name strin
} }
res = &Resource{} res = &Resource{}
url := fmt.Sprintf("%s/acl/resources", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/resources", config.Cfg.Auth.Acl.Url)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,
@@ -140,7 +140,7 @@ func DeleteResource(ctx context.Context, uid int, resourceId int) (err error) {
return return
} }
url := fmt.Sprintf("%v/acl/resources/%v", conf.Cfg.Auth.Acl.Url, resourceId) url := fmt.Sprintf("%v/acl/resources/%v", config.Cfg.Auth.Acl.Url, resourceId)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,
@@ -157,7 +157,7 @@ func UpdateResource(ctx context.Context, uid int, resourceId int, updates map[st
return return
} }
url := fmt.Sprintf("%s/acl/resources/%d", conf.Cfg.Auth.Acl.Url, resourceId) url := fmt.Sprintf("%s/acl/resources/%d", config.Cfg.Auth.Acl.Url, resourceId)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,
@@ -175,7 +175,7 @@ func GetResourcePermissions(ctx context.Context, resourceId int) (res map[string
return return
} }
res = make(map[string]*ResourcePermissionsRespItem) res = make(map[string]*ResourcePermissionsRespItem)
url := fmt.Sprintf("%v/acl/resources/%v/permissions", conf.Cfg.Auth.Acl.Url, resourceId) //TODO conf url := fmt.Sprintf("%v/acl/resources/%v/permissions", config.Cfg.Auth.Acl.Url, resourceId) //TODO config
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeader("App-Access-Token", token). SetHeader("App-Access-Token", token).
SetResult(&res). SetResult(&res).

View File

@@ -8,8 +8,8 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/remote" "github.com/veops/oneterm/pkg/remote"
) )
const ( const (
@@ -28,11 +28,11 @@ func GetRoleResources(ctx context.Context, rid int, resourceTypeId string) (res
} }
data := &ResourceResult{} data := &ResourceResult{}
url := fmt.Sprintf("%s/acl/roles/%d/resources", conf.Cfg.Auth.Acl.Url, rid) url := fmt.Sprintf("%s/acl/roles/%d/resources", config.Cfg.Auth.Acl.Url, rid)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeader("App-Access-Token", token). SetHeader("App-Access-Token", token).
SetQueryParams(map[string]string{ SetQueryParams(map[string]string{
"app_id": conf.Cfg.Auth.Acl.AppId, "app_id": config.Cfg.Auth.Acl.AppId,
"resource_type_id": resourceTypeId, "resource_type_id": resourceTypeId,
}). }).
SetResult(data). SetResult(data).
@@ -66,7 +66,7 @@ func HasPermission(ctx context.Context, rid int, resourceTypeName string, resour
} }
data := make(map[string]any) data := make(map[string]any)
url := fmt.Sprintf("%s/acl/roles/has_perm", conf.Cfg.Auth.Acl.Url) url := fmt.Sprintf("%s/acl/roles/has_perm", config.Cfg.Auth.Acl.Url)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeader("App-Access-Token", token). SetHeader("App-Access-Token", token).
SetQueryParams(map[string]string{ SetQueryParams(map[string]string{
@@ -94,7 +94,7 @@ func GrantRoleResource(ctx context.Context, uid int, roleId int, resourceId int,
return return
} }
url := fmt.Sprintf("%s/acl/roles/%d/resources/%d/grant", conf.Cfg.Auth.Acl.Url, roleId, resourceId) url := fmt.Sprintf("%s/acl/roles/%d/resources/%d/grant", config.Cfg.Auth.Acl.Url, roleId, resourceId)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,
@@ -113,7 +113,7 @@ func RevokeRoleResource(ctx context.Context, uid int, roleId int, resourceId int
return return
} }
url := fmt.Sprintf("%s/acl/roles/%d/resources/%d/revoke", conf.Cfg.Auth.Acl.Url, roleId, resourceId) url := fmt.Sprintf("%s/acl/roles/%d/resources/%d/revoke", config.Cfg.Auth.Acl.Url, roleId, resourceId)
resp, err := remote.RC.R(). resp, err := remote.RC.R().
SetHeaders(map[string]string{ SetHeaders(map[string]string{
"App-Access-Token": token, "App-Access-Token": token,

View File

@@ -0,0 +1,69 @@
package api
import (
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"github.com/veops/oneterm/internal/api/router"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
)
var (
ctx, cancel = context.WithCancel(context.Background())
srv = &http.Server{}
)
func initDB() {
cfg := db.ConfigFromGlobal()
if err := db.Init(cfg,
model.DefaultAccount, model.DefaultAsset, model.DefaultAuthorization, model.DefaultCommand,
model.DefaultConfig, model.DefaultFileHistory, model.DefaultGateway, model.DefaultHistory,
model.DefaultNode, model.DefaultPublicKey, model.DefaultSession, model.DefaultSessionCmd,
model.DefaultShare,
); err != nil {
logger.L().Fatal("Failed to init database", zap.Error(err))
}
if err := db.DropIndex(&model.Authorization{}, "asset_account_id_del"); err != nil {
logger.L().Fatal("Failed to drop index", zap.Error(err))
}
defer db.Close()
}
func RunApi() error {
initDB()
r := gin.New()
router.SetupRouter(r)
srv.Addr = fmt.Sprintf("%s:%d", config.Cfg.Http.Host, config.Cfg.Http.Port)
srv.Handler = r
logger.L().Info("Starting HTTP server",
zap.String("address", srv.Addr))
err := srv.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
logger.L().Fatal("Start HTTP server failed", zap.Error(err))
}
return err
}
func StopApi() {
defer cancel()
logger.L().Info("Stopping HTTP server")
if err := srv.Shutdown(ctx); err != nil {
logger.L().Error("Stop HTTP server failed", zap.Error(err))
}
}

View File

@@ -11,11 +11,11 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/utils"
) )
var ( var (
@@ -38,16 +38,16 @@ var (
} }
}, },
func(ctx *gin.Context, data *model.Account) { func(ctx *gin.Context, data *model.Account) {
data.Password = util.EncryptAES(data.Password) data.Password = utils.EncryptAES(data.Password)
data.Pk = util.EncryptAES(data.Pk) data.Pk = utils.EncryptAES(data.Pk)
data.Phrase = util.EncryptAES(data.Phrase) data.Phrase = utils.EncryptAES(data.Phrase)
}, },
} }
accountPostHooks = []postHook[*model.Account]{ accountPostHooks = []postHook[*model.Account]{
func(ctx *gin.Context, data []*model.Account) { func(ctx *gin.Context, data []*model.Account) {
acs := make([]*model.AccountCount, 0) acs := make([]*model.AccountCount, 0)
if err := mysql.DB. if err := dbpkg.DB.
Model(&model.Authorization{}). Model(&model.Authorization{}).
Select("account_id AS id, COUNT(*) as count"). Select("account_id AS id, COUNT(*) as count").
Group("account_id"). Group("account_id").
@@ -63,19 +63,19 @@ var (
}, },
func(ctx *gin.Context, data []*model.Account) { func(ctx *gin.Context, data []*model.Account) {
for _, d := range data { for _, d := range data {
d.Password = util.DecryptAES(d.Password) d.Password = utils.DecryptAES(d.Password)
d.Pk = util.DecryptAES(d.Pk) d.Pk = utils.DecryptAES(d.Pk)
d.Phrase = util.DecryptAES(d.Phrase) d.Phrase = utils.DecryptAES(d.Phrase)
} }
}, },
} }
accountDcs = []deleteCheck{ accountDcs = []deleteCheck{
func(ctx *gin.Context, id int) { func(ctx *gin.Context, id int) {
assetName := "" assetName := ""
err := mysql.DB. err := dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("name"). Select("name").
Where("id = (?)", mysql.DB.Model(&model.Authorization{}).Select("asset_id").Where("account_id = ?", id).Limit(1)). Where("id = (?)", dbpkg.DB.Model(&model.Authorization{}).Select("asset_id").Where("account_id = ?", id).Limit(1)).
First(&assetName). First(&assetName).
Error Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -95,7 +95,7 @@ var (
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /account [post] // @Router /account [post]
func (c *Controller) CreateAccount(ctx *gin.Context) { func (c *Controller) CreateAccount(ctx *gin.Context) {
doCreate(ctx, true, &model.Account{}, conf.RESOURCE_ACCOUNT, accountPreHooks...) doCreate(ctx, true, &model.Account{}, config.RESOURCE_ACCOUNT, accountPreHooks...)
} }
// DeleteAccount godoc // DeleteAccount godoc
@@ -105,7 +105,7 @@ func (c *Controller) CreateAccount(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /account/:id [delete] // @Router /account/:id [delete]
func (c *Controller) DeleteAccount(ctx *gin.Context) { func (c *Controller) DeleteAccount(ctx *gin.Context) {
doDelete(ctx, true, &model.Account{}, conf.RESOURCE_ACCOUNT, accountDcs...) doDelete(ctx, true, &model.Account{}, config.RESOURCE_ACCOUNT, accountDcs...)
} }
// UpdateAccount godoc // UpdateAccount godoc
@@ -116,7 +116,7 @@ func (c *Controller) DeleteAccount(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /account/:id [put] // @Router /account/:id [put]
func (c *Controller) UpdateAccount(ctx *gin.Context) { func (c *Controller) UpdateAccount(ctx *gin.Context) {
doUpdate(ctx, true, &model.Account{}, conf.RESOURCE_ACCOUNT, accountPreHooks...) doUpdate(ctx, true, &model.Account{}, config.RESOURCE_ACCOUNT, accountPreHooks...)
} }
// GetAccounts godoc // GetAccounts godoc
@@ -136,7 +136,7 @@ func (c *Controller) GetAccounts(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
db := mysql.DB.Model(&model.Account{}) db := dbpkg.DB.Model(&model.Account{})
db = filterEqual(ctx, db, "id", "type") db = filterEqual(ctx, db, "id", "type")
db = filterLike(ctx, db, "name") db = filterLike(ctx, db, "name")
db = filterSearch(ctx, db, "name", "account") db = filterSearch(ctx, db, "name", "account")
@@ -158,7 +158,7 @@ func (c *Controller) GetAccounts(ctx *gin.Context) {
db = db.Order("name") db = db.Order("name")
doGet(ctx, !info, db, conf.RESOURCE_ACCOUNT, accountPostHooks...) doGet(ctx, !info, db, config.RESOURCE_ACCOUNT, accountPostHooks...)
} }
func GetAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) { func GetAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
@@ -168,7 +168,7 @@ func GetAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
return return
} }
ss := make([]model.Slice[string], 0) ss := make([]model.Slice[string], 0)
if err = mysql.DB.Model(model.DefaultAsset).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil { if err = dbpkg.DB.Model(model.DefaultAsset).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return
} }

View File

@@ -11,13 +11,13 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/internal/schedule"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/schedule" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/logger"
) )
const ( const (
@@ -48,7 +48,7 @@ var (
// @Router /asset [post] // @Router /asset [post]
func (c *Controller) CreateAsset(ctx *gin.Context) { func (c *Controller) CreateAsset(ctx *gin.Context) {
asset := &model.Asset{} asset := &model.Asset{}
doCreate(ctx, true, asset, conf.RESOURCE_ASSET, assetPreHooks...) doCreate(ctx, true, asset, config.RESOURCE_ASSET, assetPreHooks...)
schedule.UpdateConnectables(asset.Id) schedule.UpdateConnectables(asset.Id)
} }
@@ -60,7 +60,7 @@ func (c *Controller) CreateAsset(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /asset/:id [delete] // @Router /asset/:id [delete]
func (c *Controller) DeleteAsset(ctx *gin.Context) { func (c *Controller) DeleteAsset(ctx *gin.Context) {
doDelete(ctx, true, &model.Asset{}, conf.RESOURCE_ASSET) doDelete(ctx, true, &model.Asset{}, config.RESOURCE_ASSET)
} }
// UpdateAsset godoc // UpdateAsset godoc
@@ -71,7 +71,7 @@ func (c *Controller) DeleteAsset(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /asset/:id [put] // @Router /asset/:id [put]
func (c *Controller) UpdateAsset(ctx *gin.Context) { func (c *Controller) UpdateAsset(ctx *gin.Context) {
doUpdate(ctx, true, &model.Asset{}, conf.RESOURCE_ASSET) doUpdate(ctx, true, &model.Asset{}, config.RESOURCE_ASSET)
schedule.UpdateConnectables(cast.ToInt(ctx.Param("id"))) schedule.UpdateConnectables(cast.ToInt(ctx.Param("id")))
} }
@@ -93,7 +93,7 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
db := mysql.DB.Model(model.DefaultAsset) db := dbpkg.DB.Model(model.DefaultAsset)
db = filterEqual(ctx, db, "id") db = filterEqual(ctx, db, "id")
db = filterLike(ctx, db, "name", "ip") db = filterLike(ctx, db, "name", "ip")
db = filterSearch(ctx, db, "name", "ip") db = filterSearch(ctx, db, "name", "ip")
@@ -124,11 +124,11 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
db = db.Order("name") db = db.Order("name")
doGet(ctx, !info, db, conf.RESOURCE_ASSET, assetPostHooks...) doGet(ctx, !info, db, config.RESOURCE_ASSET, assetPostHooks...)
} }
func assetPostHookCount(ctx *gin.Context, data []*model.Asset) { func assetPostHookCount(ctx *gin.Context, data []*model.Asset) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -160,8 +160,8 @@ func assetPostHookAuth(ctx *gin.Context, data []*model.Asset) {
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
noInfoIds := make([]int, 0) noInfoIds := make([]int, 0)
if !info { if !info {
t := mysql.DB.Model(model.DefaultAsset) t := dbpkg.DB.Model(model.DefaultAsset)
assetResIds, _ := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_ASSET) assetResIds, _ := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
t, _ = handleAssetIds(ctx, t, assetResIds) t, _ = handleAssetIds(ctx, t, assetResIds)
t.Pluck("id", &noInfoIds) t.Pluck("id", &noInfoIds)
} }
@@ -193,7 +193,7 @@ func assetPostHookAuth(ctx *gin.Context, data []*model.Asset) {
} }
func handleParentId(ctx context.Context, parentId int) (pids []int, err error) { func handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -257,7 +257,7 @@ func getIdsByAuthorizationIds(ctx *gin.Context) (nodeIds, assetIds, accountIds [
} }
func getAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) { func getAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) {
assets, err := util.GetAllFromCacheDb(ctx, model.DefaultAsset) assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
if err != nil { if err != nil {
return return
} }

View File

@@ -13,13 +13,13 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/logger" gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/config"
gsession "github.com/veops/oneterm/session" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/logger"
) )
// UpsertAuthorization godoc // UpsertAuthorization godoc
@@ -36,7 +36,7 @@ func (c *Controller) UpsertAuthorization(ctx *gin.Context) {
return return
} }
if err := mysql.DB.Transaction(func(tx *gorm.DB) error { if err := dbpkg.DB.Transaction(func(tx *gorm.DB) error {
t := &model.Authorization{} t := &model.Authorization{}
if err = tx.Model(t). if err = tx.Model(t).
Where("node_id=? AND asset_id=? AND account_id=?", auth.NodeId, auth.AssetId, auth.AccountId). Where("node_id=? AND asset_id=? AND account_id=?", auth.NodeId, auth.AssetId, auth.AccountId).
@@ -82,7 +82,7 @@ func (c *Controller) DeleteAuthorization(ctx *gin.Context) {
Id: cast.ToInt(ctx.Param("id")), Id: cast.ToInt(ctx.Param("id")),
} }
if err := mysql.DB.Model(auth).Where("id=?", auth.Id).First(auth); err != nil { if err := dbpkg.DB.Model(auth).Where("id=?", auth.Id).First(auth); err != nil {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument, Data: map[string]any{"err": err}})
return return
} }
@@ -92,7 +92,7 @@ func (c *Controller) DeleteAuthorization(ctx *gin.Context) {
return return
} }
if err := handleAuthorization(ctx, mysql.DB, model.ACTION_DELETE, nil, auth); err != nil { if err := handleAuthorization(ctx, dbpkg.DB, model.ACTION_DELETE, nil, auth); err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return
} }
@@ -119,7 +119,7 @@ func (c *Controller) GetAuthorizations(ctx *gin.Context) {
AccountId: cast.ToInt(ctx.Query("account_id")), AccountId: cast.ToInt(ctx.Query("account_id")),
NodeId: cast.ToInt(ctx.Query("node_id")), NodeId: cast.ToInt(ctx.Query("node_id")),
} }
db := mysql.DB.Model(auth) db := dbpkg.DB.Model(auth)
for _, k := range []string{"node_id", "asset_id", "account_id"} { for _, k := range []string{"node_id", "asset_id", "account_id"} {
q, _ := ctx.GetQuery(k) q, _ := ctx.GetQuery(k)
db = db.Where(fmt.Sprintf("%s=?", k), cast.ToInt(q)) db = db.Where(fmt.Sprintf("%s=?", k), cast.ToInt(q))
@@ -136,7 +136,7 @@ func (c *Controller) GetAuthorizations(ctx *gin.Context) {
return return
} }
doGet[*model.Authorization](ctx, false, db, conf.RESOURCE_AUTHORIZATION) doGet[*model.Authorization](ctx, false, db, config.RESOURCE_AUTHORIZATION)
} }
func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds, assetIds, accountIds []int, err error) { func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds, assetIds, accountIds []int, err error) {
@@ -147,13 +147,13 @@ func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds,
eg.Go(func() (err error) { eg.Go(func() (err error) {
defer close(ch) defer close(ch)
res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), conf.RESOURCE_NODE) res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
if err != nil { if err != nil {
return return
} }
res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) }) res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) })
resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId }) resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId })
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -164,14 +164,14 @@ func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds,
}) })
eg.Go(func() (err error) { eg.Go(func() (err error) {
res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), conf.RESOURCE_ASSET) res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
if err != nil { if err != nil {
return return
} }
res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) }) res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) })
resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId }) resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId })
<-ch <-ch
assets, err := util.GetAllFromCacheDb(ctx, model.DefaultAsset) assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
if err != nil { if err != nil {
return return
} }
@@ -183,13 +183,13 @@ func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds,
}) })
eg.Go(func() (err error) { eg.Go(func() (err error) {
res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), conf.RESOURCE_ACCOUNT) res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), config.RESOURCE_ACCOUNT)
if err != nil { if err != nil {
return return
} }
res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) }) res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) })
resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId }) resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId })
accounts, err := util.GetAllFromCacheDb(ctx, model.DefaultAccount) accounts, err := repository.GetAllFromCacheDb(ctx, model.DefaultAccount)
if err != nil { if err != nil {
return return
} }
@@ -231,12 +231,12 @@ func hasPermAuthorization(ctx context.Context, auth *model.Authorization, action
} }
func getAuthsByAsset(t *model.Asset) (data []*model.Authorization, err error) { func getAuthsByAsset(t *model.Asset) (data []*model.Authorization, err error) {
err = mysql.DB.Model(data).Where("asset_id=? AND account_id IN ? AND node_id=0", t.Id, lo.Without(lo.Keys(t.Authorization), 0)).Find(&data).Error err = dbpkg.DB.Model(data).Where("asset_id=? AND account_id IN ? AND node_id=0", t.Id, lo.Without(lo.Keys(t.Authorization), 0)).Find(&data).Error
return return
} }
func handleAuthorization(ctx *gin.Context, tx *gorm.DB, action int, asset *model.Asset, auths ...*model.Authorization) (err error) { func handleAuthorization(ctx *gin.Context, tx *gorm.DB, action int, asset *model.Asset, auths ...*model.Authorization) (err error) {
defer util.DeleteAllFromCacheDb(ctx, model.DefaultAuthorization) defer repository.DeleteAllFromCacheDb(ctx, model.DefaultAuthorization)
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
@@ -266,7 +266,7 @@ func handleAuthorization(ctx *gin.Context, tx *gorm.DB, action int, asset *model
if err = acl.DeleteResource(ctx, currentUser.GetUid(), p.ResourceId); err != nil { if err = acl.DeleteResource(ctx, currentUser.GetUid(), p.ResourceId); err != nil {
return return
} }
if err = mysql.DB.Model(p).Where("id=?", p.Id).Delete(p).Error; err != nil { if err = dbpkg.DB.Model(p).Where("id=?", p.Id).Delete(p).Error; err != nil {
return return
} }
return return
@@ -288,7 +288,7 @@ func handleAuthorization(ctx *gin.Context, tx *gorm.DB, action int, asset *model
case model.ACTION_CREATE: case model.ACTION_CREATE:
eg.Go(func() (err error) { eg.Go(func() (err error) {
resourceId := 0 resourceId := 0
if resourceId, err = acl.CreateAcl(ctx, currentUser, conf.RESOURCE_AUTHORIZATION, auth.GetName()); err != nil { if resourceId, err = acl.CreateAcl(ctx, currentUser, config.RESOURCE_AUTHORIZATION, auth.GetName()); err != nil {
return return
} }
if err = acl.BatchGrantRoleResource(ctx, currentUser.GetUid(), auth.Rids, resourceId, []string{acl.READ}); err != nil { if err = acl.BatchGrantRoleResource(ctx, currentUser.GetUid(), auth.Rids, resourceId, []string{acl.READ}); err != nil {
@@ -306,12 +306,12 @@ func handleAuthorization(ctx *gin.Context, tx *gorm.DB, action int, asset *model
case model.ACTION_UPDATE: case model.ACTION_UPDATE:
eg.Go(func() (err error) { eg.Go(func() (err error) {
pre := &model.Authorization{} pre := &model.Authorization{}
if err = mysql.DB.Where("id=?", auth.GetId()).First(pre).Error; err != nil { if err = dbpkg.DB.Where("id=?", auth.GetId()).First(pre).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) { if !errors.Is(err, gorm.ErrRecordNotFound) {
return return
} }
resourceId := 0 resourceId := 0
if resourceId, err = acl.CreateAcl(ctx, currentUser, conf.RESOURCE_AUTHORIZATION, auth.GetName()); err != nil { if resourceId, err = acl.CreateAcl(ctx, currentUser, config.RESOURCE_AUTHORIZATION, auth.GetName()); err != nil {
return return
} }
auth.ResourceId = resourceId auth.ResourceId = resourceId
@@ -344,7 +344,7 @@ func handleAuthorization(ctx *gin.Context, tx *gorm.DB, action int, asset *model
func getAuthorizations(ctx *gin.Context) (res []*acl.Resource, err error) { func getAuthorizations(ctx *gin.Context) (res []*acl.Resource, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
res, err = acl.GetRoleResources(ctx, currentUser.GetRid(), conf.RESOURCE_AUTHORIZATION) res, err = acl.GetRoleResources(ctx, currentUser.GetRid(), config.RESOURCE_AUTHORIZATION)
if err != nil { if err != nil {
return return
} }
@@ -370,7 +370,7 @@ func getAuthorizationIds(ctx *gin.Context) (authIds []*model.AuthorizationIds, e
return return
} }
err = mysql.DB.Model(authIds).Where("resource_id IN ?", resourceIds).Find(&authIds).Error err = dbpkg.DB.Model(authIds).Where("resource_id IN ?", resourceIds).Find(&authIds).Error
return return
} }
@@ -386,7 +386,7 @@ func hasAuthorization(ctx *gin.Context, sess *gsession.Session) (ok bool) {
} }
if sess.Session.Asset == nil { if sess.Session.Asset == nil {
if err := mysql.DB.Model(sess.Session.Asset).Where("id=?", sess.AssetId).First(&sess.Session.Asset).Error; err != nil { if err := dbpkg.DB.Model(sess.Session.Asset).Where("id=?", sess.AssetId).First(&sess.Session.Asset).Error; err != nil {
return return
} }
} }

View File

@@ -12,10 +12,10 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
) )
var ( var (
@@ -33,7 +33,7 @@ var (
commandDcs = []deleteCheck{ commandDcs = []deleteCheck{
func(ctx *gin.Context, id int) { func(ctx *gin.Context, id int) {
assetName := "" assetName := ""
err := mysql.DB. err := dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("name"). Select("name").
Where(fmt.Sprintf("JSON_CONTAINS(cmd_ids, '%d')", id)). Where(fmt.Sprintf("JSON_CONTAINS(cmd_ids, '%d')", id)).
@@ -56,7 +56,7 @@ var (
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /command [post] // @Router /command [post]
func (c *Controller) CreateCommand(ctx *gin.Context) { func (c *Controller) CreateCommand(ctx *gin.Context) {
doCreate(ctx, true, &model.Command{}, conf.RESOURCE_COMMAND, commandPreHooks...) doCreate(ctx, true, &model.Command{}, config.RESOURCE_COMMAND, commandPreHooks...)
} }
// DeleteCommand godoc // DeleteCommand godoc
@@ -66,7 +66,7 @@ func (c *Controller) CreateCommand(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /command/:id [delete] // @Router /command/:id [delete]
func (c *Controller) DeleteCommand(ctx *gin.Context) { func (c *Controller) DeleteCommand(ctx *gin.Context) {
doDelete(ctx, true, &model.Command{}, conf.RESOURCE_COMMAND, commandDcs...) doDelete(ctx, true, &model.Command{}, config.RESOURCE_COMMAND, commandDcs...)
} }
// UpdateCommand godoc // UpdateCommand godoc
@@ -77,7 +77,7 @@ func (c *Controller) DeleteCommand(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /command/:id [put] // @Router /command/:id [put]
func (c *Controller) UpdateCommand(ctx *gin.Context) { func (c *Controller) UpdateCommand(ctx *gin.Context) {
doUpdate(ctx, true, &model.Command{}, conf.RESOURCE_COMMAND, commandPreHooks...) doUpdate(ctx, true, &model.Command{}, config.RESOURCE_COMMAND, commandPreHooks...)
} }
// GetCommands godoc // GetCommands godoc
@@ -98,7 +98,7 @@ func (c *Controller) GetCommands(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
db := mysql.DB.Model(&model.Command{}) db := dbpkg.DB.Model(&model.Command{})
db = filterEqual(ctx, db, "id", "enable") db = filterEqual(ctx, db, "id", "enable")
db = filterLike(ctx, db, "name") db = filterLike(ctx, db, "name")
db = filterSearch(ctx, db, "name", "cmd") db = filterSearch(ctx, db, "name", "cmd")
@@ -108,17 +108,17 @@ func (c *Controller) GetCommands(ctx *gin.Context) {
if info && !acl.IsAdmin(currentUser) { if info && !acl.IsAdmin(currentUser) {
//rs := make([]*acl.Resource, 0) //rs := make([]*acl.Resource, 0)
rs, err := acl.GetRoleResources(ctx, currentUser.Acl.Rid, conf.RESOURCE_AUTHORIZATION) rs, err := acl.GetRoleResources(ctx, currentUser.Acl.Rid, config.RESOURCE_AUTHORIZATION)
if err != nil { if err != nil {
handleRemoteErr(ctx, err) handleRemoteErr(ctx, err)
return return
} }
sub := mysql.DB. sub := dbpkg.DB.
Model(&model.Authorization{}). Model(&model.Authorization{}).
Select("DISTINCT asset_id"). Select("DISTINCT asset_id").
Where("resource_id IN ?", lo.Map(rs, func(r *acl.Resource, _ int) int { return r.ResourceId })) Where("resource_id IN ?", lo.Map(rs, func(r *acl.Resource, _ int) int { return r.ResourceId }))
cmdIds := make([]model.Slice[int], 0) cmdIds := make([]model.Slice[int], 0)
if err = mysql.DB. if err = dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("cmd_ids"). Select("cmd_ids").
Where("id IN (?)", sub). Where("id IN (?)", sub).
@@ -137,5 +137,5 @@ func (c *Controller) GetCommands(ctx *gin.Context) {
db = db.Order("name") db = db.Order("name")
doGet[*model.Command](ctx, !info, db, conf.RESOURCE_COMMAND) doGet[*model.Command](ctx, !info, db, config.RESOURCE_COMMAND)
} }

View File

@@ -9,10 +9,10 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
redis "github.com/veops/oneterm/cache" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/pkg/cache"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
) )
// PostConfig godoc // PostConfig godoc
@@ -37,7 +37,7 @@ func (c *Controller) PostConfig(ctx *gin.Context) {
cfg.CreatorId = currentUser.GetUid() cfg.CreatorId = currentUser.GetUid()
cfg.UpdaterId = currentUser.GetUid() cfg.UpdaterId = currentUser.GetUid()
if err := mysql.DB.Model(cfg).Transaction(func(tx *gorm.DB) error { if err := dbpkg.DB.Model(cfg).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("deleted_at = 0").Delete(&model.Config{}).Error; err != nil { if err := tx.Where("deleted_at = 0").Delete(&model.Config{}).Error; err != nil {
return err return err
} }
@@ -48,7 +48,7 @@ func (c *Controller) PostConfig(ctx *gin.Context) {
} }
model.GlobalConfig.Store(cfg) model.GlobalConfig.Store(cfg)
redis.SetEx(ctx, "config", cfg, time.Hour) cache.SetEx(ctx, "config", cfg, time.Hour)
ctx.JSON(http.StatusOK, defaultHttpResponse) ctx.JSON(http.StatusOK, defaultHttpResponse)
} }
@@ -67,7 +67,7 @@ func (c *Controller) GetConfig(ctx *gin.Context) {
} }
cfg := &model.Config{} cfg := &model.Config{}
if err := mysql.DB.Model(cfg).First(&cfg).Error; err != nil { if err := dbpkg.DB.Model(cfg).First(&cfg).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) { if !errors.Is(err, gorm.ErrRecordNotFound) {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return

View File

@@ -29,15 +29,15 @@ import (
mysqlDriver "gorm.io/driver/mysql" mysqlDriver "gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/api/guacd" ggateway "github.com/veops/oneterm/internal/gateway"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/guacd"
ggateway "github.com/veops/oneterm/gateway" myi18n "github.com/veops/oneterm/internal/i18n"
myi18n "github.com/veops/oneterm/i18n" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/model" gsession "github.com/veops/oneterm/internal/session"
gsession "github.com/veops/oneterm/session" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -165,7 +165,7 @@ func HandleTerm(sess *gsession.Session) (err error) {
writeErrMsg(sess, "idle timeout\n\n") writeErrMsg(sess, "idle timeout\n\n")
return &ApiError{Code: ErrIdleTimeout, Data: map[string]any{"second": model.GlobalConfig.Load().Timeout}} return &ApiError{Code: ErrIdleTimeout, Data: map[string]any{"second": model.GlobalConfig.Load().Timeout}}
case <-tk1m.C: case <-tk1m.C:
if mysql.DB.Model(asset).Where("id = ?", sess.AssetId).First(asset).Error != nil { if dbpkg.DB.Model(asset).Where("id = ?", sess.AssetId).First(asset).Error != nil {
continue continue
} }
if checkTime(asset.AccessAuth) && (sess.ShareId == 0 || time.Now().Before(sess.ShareEnd)) { if checkTime(asset.AccessAuth) && (sess.ShareId == 0 || time.Now().Before(sess.ShareEnd)) {
@@ -259,7 +259,7 @@ func handleGuacd(sess *gsession.Session) (err error) {
case <-sess.IdleTk.C: case <-sess.IdleTk.C:
return &ApiError{Code: ErrIdleTimeout, Data: map[string]any{"second": model.GlobalConfig.Load().Timeout}} return &ApiError{Code: ErrIdleTimeout, Data: map[string]any{"second": model.GlobalConfig.Load().Timeout}}
case <-tk.C: case <-tk.C:
if mysql.DB.Model(asset).Where("id = ?", sess.AssetId).First(asset).Error != nil { if dbpkg.DB.Model(asset).Where("id = ?", sess.AssetId).First(asset).Error != nil {
continue continue
} }
if checkTime(asset.AccessAuth) && (sess.ShareId == 0 || time.Now().Before(sess.ShareEnd)) { if checkTime(asset.AccessAuth) && (sess.ShareId == 0 || time.Now().Before(sess.ShareEnd)) {
@@ -298,7 +298,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
assetId, accountId := cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id")) assetId, accountId := cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))
asset, account, gateway, err := util.GetAAG(assetId, accountId) asset, account, gateway, err := service.GetAAG(assetId, accountId)
if err != nil { if err != nil {
return return
} }
@@ -330,7 +330,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
if !sess.IsGuacd() { if !sess.IsGuacd() {
w, h := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h")) w, h := cast.ToInt(ctx.Query("w")), cast.ToInt(ctx.Query("h"))
sess.SshParser = gsession.NewParser(sess.SessionId, w, h) sess.SshParser = gsession.NewParser(sess.SessionId, w, h)
if err = mysql.DB.Model(sess.SshParser.Cmds).Where("id IN ? AND enable=?", []int(asset.AccessAuth.CmdIds), true). if err = dbpkg.DB.Model(sess.SshParser.Cmds).Where("id IN ? AND enable=?", []int(asset.AccessAuth.CmdIds), true).
Find(&sess.SshParser.Cmds).Error; err != nil { Find(&sess.SshParser.Cmds).Error; err != nil {
return return
} }
@@ -363,7 +363,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
go connectSsh(ctx, sess, asset, account, gateway) go connectSsh(ctx, sess, asset, account, gateway)
case "redis", "mysql": case "redis", "mysql":
go connectOther(ctx, sess, asset, account, gateway) go connectOther(ctx, sess, asset, account, gateway)
case "vnc", "rdp": case "vnc", "rdp", "telnet":
go connectGuacd(ctx, sess, asset, account, gateway) go connectGuacd(ctx, sess, asset, account, gateway)
default: default:
logger.L().Error("wrong protocol " + sess.Protocol) logger.L().Error("wrong protocol " + sess.Protocol)
@@ -391,12 +391,12 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
} }
}() }()
ip, port, err := util.Proxy(false, sess.SessionId, "ssh", asset, gateway) ip, port, err := service.Proxy(false, sess.SessionId, "ssh", asset, gateway)
if err != nil { if err != nil {
return return
} }
auth, err := util.GetAuth(account) auth, err := service.GetAuth(account)
if err != nil { if err != nil {
return return
} }
@@ -557,7 +557,7 @@ func connectOther(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
}() }()
protocol := strings.Split(sess.Protocol, ":")[0] protocol := strings.Split(sess.Protocol, ":")[0]
ip, port, err := util.Proxy(false, sess.SessionId, protocol, asset, gateway) ip, port, err := service.Proxy(false, sess.SessionId, protocol, asset, gateway)
if err != nil { if err != nil {
return return
} }
@@ -867,7 +867,7 @@ func (c *Controller) ConnectClose(ctx *gin.Context) {
} }
session := &gsession.Session{} session := &gsession.Session{}
err := mysql.DB. err := dbpkg.DB.
Model(session). Model(session).
Where("session_id = ?", ctx.Param("session_id")). Where("session_id = ?", ctx.Param("session_id")).
Where("status = ?", model.SESSIONSTATUS_ONLINE). Where("status = ?", model.SESSIONSTATUS_ONLINE).

View File

@@ -15,12 +15,12 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/remote" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/remote"
) )
var ( var (
@@ -61,7 +61,7 @@ func NewHttpResponseWithData(data any) *HttpResponse {
} }
func doCreate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType string, preHooks ...preHook[T]) (err error) { func doCreate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType string, preHooks ...preHook[T]) (err error) {
defer util.DeleteAllFromCacheDb(ctx, md) defer repository.DeleteAllFromCacheDb(ctx, md)
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
@@ -94,7 +94,7 @@ func doCreate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
md.SetCreatorId(currentUser.Uid) md.SetCreatorId(currentUser.Uid)
md.SetUpdaterId(currentUser.Uid) md.SetUpdaterId(currentUser.Uid)
if err = mysql.DB.Transaction(func(tx *gorm.DB) (err error) { if err = dbpkg.DB.Transaction(func(tx *gorm.DB) (err error) {
if err = tx.Model(md).Create(md).Error; err != nil { if err = tx.Model(md).Create(md).Error; err != nil {
return return
} }
@@ -145,7 +145,7 @@ func doCreate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
} }
func doDelete[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType string, dcs ...deleteCheck) (err error) { func doDelete[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType string, dcs ...deleteCheck) (err error) {
defer util.DeleteAllFromCacheDb(ctx, md) defer repository.DeleteAllFromCacheDb(ctx, md)
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
@@ -155,7 +155,7 @@ func doDelete[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
return return
} }
if err = mysql.DB.Model(md).Where("id = ?", id).First(md).Error; err != nil { if err = dbpkg.DB.Model(md).Where("id = ?", id).First(md).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
ctx.JSON(http.StatusOK, HttpResponse{ ctx.JSON(http.StatusOK, HttpResponse{
Data: map[string]any{ Data: map[string]any{
@@ -187,7 +187,7 @@ func doDelete[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
} }
} }
if err = mysql.DB.Transaction(func(tx *gorm.DB) (err error) { if err = dbpkg.DB.Transaction(func(tx *gorm.DB) (err error) {
switch t := any(md).(type) { switch t := any(md).(type) {
case *model.Asset: case *model.Asset:
if err = handleAuthorization(ctx, tx, model.ACTION_DELETE, t, nil, nil); err != nil { if err = handleAuthorization(ctx, tx, model.ACTION_DELETE, t, nil, nil); err != nil {
@@ -229,7 +229,7 @@ func doDelete[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
} }
func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType string, preHooks ...preHook[T]) (err error) { func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType string, preHooks ...preHook[T]) (err error) {
defer util.DeleteAllFromCacheDb(ctx, md) defer repository.DeleteAllFromCacheDb(ctx, md)
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
@@ -256,7 +256,7 @@ func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
} }
old := getEmpty(md) old := getEmpty(md)
if err = mysql.DB.Model(md).Where("id = ?", id).First(old).Error; err != nil { if err = dbpkg.DB.Model(md).Where("id = ?", id).First(old).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
ctx.JSON(http.StatusOK, defaultHttpResponse) ctx.JSON(http.StatusOK, defaultHttpResponse)
return return
@@ -282,7 +282,7 @@ func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
} }
md.SetId(id) md.SetId(id)
if err = mysql.DB.Transaction(func(tx *gorm.DB) (err error) { if err = dbpkg.DB.Transaction(func(tx *gorm.DB) (err error) {
omits := []string{"resource_id", "created_at", "deleted_at"} omits := []string{"resource_id", "created_at", "deleted_at"}
selects := []string{"*"} selects := []string{"*"}
switch t := any(md).(type) { switch t := any(md).(type) {
@@ -300,10 +300,10 @@ func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
} }
} }
if err = mysql.DB.Select(selects).Omit(omits...).Save(md).Error; err != nil { if err = dbpkg.DB.Select(selects).Omit(omits...).Save(md).Error; err != nil {
return return
} }
err = mysql.DB.Create(&model.History{ err = dbpkg.DB.Create(&model.History{
RemoteIp: ctx.ClientIP(), RemoteIp: ctx.ClientIP(),
Type: md.TableName(), Type: md.TableName(),
TargetId: md.GetId(), TargetId: md.GetId(),
@@ -419,7 +419,7 @@ func filterSearch(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
return db return db
} }
d := mysql.DB d := dbpkg.DB
for _, f := range fields { for _, f := range fields {
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q)) d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
} }
@@ -459,7 +459,7 @@ func filterEqual(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
} }
func filterLike(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB { func filterLike(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
likes := false likes := false
d := mysql.DB d := dbpkg.DB
for _, f := range fields { for _, f := range fields {
if q, ok := ctx.GetQuery(f); ok && q != "" { if q, ok := ctx.GetQuery(f); ok && q != "" {
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q)) d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
@@ -507,10 +507,10 @@ func hasPerm[T model.Model](ctx context.Context, md T, resourceTypeName, action
} }
if len(pids) > 0 { if len(pids) > 0 {
res, _ := acl.GetRoleResources(ctx, currentUser.GetRid(), conf.RESOURCE_NODE) res, _ := acl.GetRoleResources(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
resId2perms := lo.SliceToMap(res, func(r *acl.Resource) (int, []string) { return r.ResourceId, r.Permissions }) resId2perms := lo.SliceToMap(res, func(r *acl.Resource) (int, []string) { return r.ResourceId, r.Permissions })
resId2perms, _ = handleSelfChildPerms(ctx, resId2perms) resId2perms, _ = handleSelfChildPerms(ctx, resId2perms)
nodes, _ := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, _ := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
id2resId := lo.SliceToMap(nodes, func(n *model.Node) (int, int) { return n.Id, n.ResourceId }) id2resId := lo.SliceToMap(nodes, func(n *model.Node) (int, int) { return n.Id, n.ResourceId })
if lo.ContainsBy(pids, func(pid int) bool { return lo.Contains(resId2perms[id2resId[pid]], action) }) { if lo.ContainsBy(pids, func(pid int) bool { return lo.Contains(resId2perms[id2resId[pid]], action) }) {
return true return true
@@ -527,7 +527,7 @@ func handlePermissions[T any](ctx *gin.Context, data []T, resourceTypeName strin
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
if !lo.Contains(conf.PermResource, resourceTypeName) { if !lo.Contains(config.PermResource, resourceTypeName) {
return return
} }
@@ -546,7 +546,7 @@ func handlePermissions[T any](ctx *gin.Context, data []T, resourceTypeName strin
return return
} }
case []*model.Asset: case []*model.Asset:
res, err = acl.GetRoleResources(ctx, currentUser.GetRid(), conf.RESOURCE_NODE) res, err = acl.GetRoleResources(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
if err != nil { if err != nil {
handleRemoteErr(ctx, err) handleRemoteErr(ctx, err)
return return
@@ -556,12 +556,16 @@ func handlePermissions[T any](ctx *gin.Context, data []T, resourceTypeName strin
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return
} }
nodeId2ResId := make(map[int]int) var nodeId2ResId map[int]int
if nodeId2ResId, err = getNodeId2ResId(ctx); err != nil { nodeId2ResId, err = getNodeId2ResId(ctx)
if err != nil {
return return
} }
for _, d := range ds { for _, d := range ds {
resId2perms[d.GetResourceId()] = append(resId2perms[d.GetResourceId()], nodeResId2perms[nodeId2ResId[d.ParentId]]...) resId2perms[d.GetResourceId()] = append(
resId2perms[d.GetResourceId()],
nodeResId2perms[nodeId2ResId[d.ParentId]]...,
)
} }
} }
@@ -609,7 +613,7 @@ func handleAcl[T any](ctx *gin.Context, dbFind *gorm.DB, resourceType string) (d
func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) { func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -619,12 +623,12 @@ func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB
return return
} }
assetResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_ASSET) assetResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
if err != nil { if err != nil {
return return
} }
assets := make([]*model.AssetIdPid, 0) assets := make([]*model.AssetIdPid, 0)
if err = mysql.DB.Model(assets).Where("resource_id IN ?", assetResIds).Find(&assets).Error; err != nil { if err = dbpkg.DB.Model(assets).Where("resource_id IN ?", assetResIds).Find(&assets).Error; err != nil {
return return
} }
ids = append(ids, lo.Map(assets, func(a *model.AssetIdPid, _ int) int { return a.ParentId })...) ids = append(ids, lo.Map(assets, func(a *model.AssetIdPid, _ int) int { return a.ParentId })...)
@@ -642,11 +646,11 @@ func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB
func handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) { func handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
nodeResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_NODE) nodeResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
if err != nil { if err != nil {
return return
} }
@@ -656,7 +660,7 @@ func handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.D
return return
} }
d := mysql.DB.Where("resource_id IN ?", resIds).Or("parent_id IN?", nodeIds) d := dbpkg.DB.Where("resource_id IN ?", resIds).Or("parent_id IN?", nodeIds)
db = dbFind.Where(d) db = dbFind.Where(d)
@@ -666,18 +670,18 @@ func handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.D
func handleAccountIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) { func handleAccountIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
assetResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_ASSET) assetResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
if err != nil { if err != nil {
return return
} }
t, _ := handleAssetIds(ctx, mysql.DB.Model(model.DefaultAsset), assetResIds) t, _ := handleAssetIds(ctx, dbpkg.DB.Model(model.DefaultAsset), assetResIds)
ss := make([]model.Slice[string], 0) ss := make([]model.Slice[string], 0)
if err = t.Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil { if err = t.Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
return return
} }
ids := lo.Uniq(lo.Map(lo.Flatten(ss), func(s string, _ int) int { return cast.ToInt(s) })) ids := lo.Uniq(lo.Map(lo.Flatten(ss), func(s string, _ int) int { return cast.ToInt(s) }))
d := mysql.DB.Where("resource_id IN ?", resIds).Or("id IN ?", ids) d := dbpkg.DB.Where("resource_id IN ?", resIds).Or("id IN ?", ids)
db = dbFind.Where(d) db = dbFind.Where(d)

View File

@@ -7,7 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
myi18n "github.com/veops/oneterm/i18n" myi18n "github.com/veops/oneterm/internal/i18n"
) )
const ( const (

View File

@@ -13,12 +13,12 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/api/file" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/logger" gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
gsession "github.com/veops/oneterm/session" "github.com/veops/oneterm/pkg/logger"
) )
// GetFileHistory godoc // GetFileHistory godoc
@@ -37,7 +37,7 @@ import (
// @Success 200 {object} HttpResponse{data=ListData{list=[]model.Session}} // @Success 200 {object} HttpResponse{data=ListData{list=[]model.Session}}
// @Router /file/history [get] // @Router /file/history [get]
func (c *Controller) GetFileHistory(ctx *gin.Context) { func (c *Controller) GetFileHistory(ctx *gin.Context) {
db := mysql.DB.Model(&model.FileHistory{}) db := dbpkg.DB.Model(&model.FileHistory{})
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
if !acl.IsAdmin(currentUser) { if !acl.IsAdmin(currentUser) {
db = db.Where("uid = ?", currentUser.Uid) db = db.Where("uid = ?", currentUser.Uid)
@@ -73,7 +73,7 @@ func (c *Controller) FileLS(ctx *gin.Context) {
return return
} }
cli, err := file.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))) cli, err := service.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id")))
if err != nil { if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return
@@ -87,7 +87,7 @@ func (c *Controller) FileLS(ctx *gin.Context) {
res := &ListData{ res := &ListData{
Count: int64(len(info)), Count: int64(len(info)),
List: lo.Map(info, func(f fs.FileInfo, _ int) any { List: lo.Map(info, func(f fs.FileInfo, _ int) any {
return &file.FileInfo{ return &service.FileInfo{
Name: f.Name(), Name: f.Name(),
IsDir: f.IsDir(), IsDir: f.IsDir(),
Size: f.Size(), Size: f.Size(),
@@ -121,7 +121,7 @@ func (c *Controller) FileMkdir(ctx *gin.Context) {
return return
} }
cli, err := file.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))) cli, err := service.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id")))
if err != nil { if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{}})
return return
@@ -140,7 +140,7 @@ func (c *Controller) FileMkdir(ctx *gin.Context) {
Action: model.FILE_ACTION_MKDIR, Action: model.FILE_ACTION_MKDIR,
Dir: ctx.Query("dir"), Dir: ctx.Query("dir"),
} }
if err = mysql.DB.Model(h).Create(h).Error; err != nil { if err = dbpkg.DB.Model(h).Create(h).Error; err != nil {
logger.L().Error("record mkdir failed", zap.Error(err), zap.Any("history", h)) logger.L().Error("record mkdir failed", zap.Error(err), zap.Any("history", h))
} }
ctx.JSON(http.StatusOK, defaultHttpResponse) ctx.JSON(http.StatusOK, defaultHttpResponse)
@@ -181,7 +181,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
return return
} }
cli, err := file.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))) cli, err := service.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id")))
if err != nil { if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{}})
return return
@@ -206,7 +206,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
Dir: ctx.Query("dir"), Dir: ctx.Query("dir"),
Filename: fh.Filename, Filename: fh.Filename,
} }
if err = mysql.DB.Model(h).Create(h).Error; err != nil { if err = dbpkg.DB.Model(h).Create(h).Error; err != nil {
logger.L().Error("record upload failed", zap.Error(err), zap.Any("history", h)) logger.L().Error("record upload failed", zap.Error(err), zap.Any("history", h))
} }
@@ -238,7 +238,7 @@ func (c *Controller) FileDownload(ctx *gin.Context) {
return return
} }
cli, err := file.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))) cli, err := service.GetFileManager().GetFileClient(cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id")))
if err != nil { if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{}})
return return
@@ -268,7 +268,7 @@ func (c *Controller) FileDownload(ctx *gin.Context) {
Filename: ctx.Query("filename"), Filename: ctx.Query("filename"),
} }
if err = mysql.DB.Model(h).Create(h).Error; err != nil { if err = dbpkg.DB.Model(h).Create(h).Error; err != nil {
logger.L().Error("record download failed", zap.Error(err), zap.Any("history", h)) logger.L().Error("record download failed", zap.Error(err), zap.Any("history", h))
} }
} }

View File

@@ -11,11 +11,11 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/utils"
) )
var ( var (
@@ -38,15 +38,15 @@ var (
} }
}, },
func(ctx *gin.Context, data *model.Gateway) { func(ctx *gin.Context, data *model.Gateway) {
data.Password = util.EncryptAES(data.Password) data.Password = utils.EncryptAES(data.Password)
data.Pk = util.EncryptAES(data.Pk) data.Pk = utils.EncryptAES(data.Pk)
data.Phrase = util.EncryptAES(data.Phrase) data.Phrase = utils.EncryptAES(data.Phrase)
}, },
} }
gatewayPostHooks = []postHook[*model.Gateway]{ gatewayPostHooks = []postHook[*model.Gateway]{
func(ctx *gin.Context, data []*model.Gateway) { func(ctx *gin.Context, data []*model.Gateway) {
post := make([]*model.GatewayCount, 0) post := make([]*model.GatewayCount, 0)
if err := mysql.DB. if err := dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("gateway_id AS id, COUNT(*) AS count"). Select("gateway_id AS id, COUNT(*) AS count").
Where("gateway_id IN ?", lo.Map(data, func(d *model.Gateway, _ int) int { return d.Id })). Where("gateway_id IN ?", lo.Map(data, func(d *model.Gateway, _ int) int { return d.Id })).
@@ -62,16 +62,16 @@ var (
}, },
func(ctx *gin.Context, data []*model.Gateway) { func(ctx *gin.Context, data []*model.Gateway) {
for _, d := range data { for _, d := range data {
d.Password = util.DecryptAES(d.Password) d.Password = utils.DecryptAES(d.Password)
d.Pk = util.DecryptAES(d.Pk) d.Pk = utils.DecryptAES(d.Pk)
d.Phrase = util.DecryptAES(d.Phrase) d.Phrase = utils.DecryptAES(d.Phrase)
} }
}, },
} }
gatewayDcs = []deleteCheck{ gatewayDcs = []deleteCheck{
func(ctx *gin.Context, id int) { func(ctx *gin.Context, id int) {
assetName := "" assetName := ""
err := mysql.DB. err := dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("name"). Select("name").
Where("gateway_id = ?", id). Where("gateway_id = ?", id).
@@ -94,7 +94,7 @@ var (
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /gateway [post] // @Router /gateway [post]
func (c *Controller) CreateGateway(ctx *gin.Context) { func (c *Controller) CreateGateway(ctx *gin.Context) {
doCreate(ctx, true, &model.Gateway{}, conf.RESOURCE_GATEWAY, gatewayPreHooks...) doCreate(ctx, true, &model.Gateway{}, config.RESOURCE_GATEWAY, gatewayPreHooks...)
} }
// DeleteGateway godoc // DeleteGateway godoc
@@ -104,7 +104,7 @@ func (c *Controller) CreateGateway(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /gateway/:id [delete] // @Router /gateway/:id [delete]
func (c *Controller) DeleteGateway(ctx *gin.Context) { func (c *Controller) DeleteGateway(ctx *gin.Context) {
doDelete(ctx, true, &model.Gateway{}, conf.RESOURCE_GATEWAY, gatewayDcs...) doDelete(ctx, true, &model.Gateway{}, config.RESOURCE_GATEWAY, gatewayDcs...)
} }
// UpdateGateway godoc // UpdateGateway godoc
@@ -115,7 +115,7 @@ func (c *Controller) DeleteGateway(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /gateway/:id [put] // @Router /gateway/:id [put]
func (c *Controller) UpdateGateway(ctx *gin.Context) { func (c *Controller) UpdateGateway(ctx *gin.Context) {
doUpdate(ctx, true, &model.Gateway{}, conf.RESOURCE_GATEWAY, gatewayPreHooks...) doUpdate(ctx, true, &model.Gateway{}, config.RESOURCE_GATEWAY, gatewayPreHooks...)
} }
// GetGateways godoc // GetGateways godoc
@@ -135,7 +135,7 @@ func (c *Controller) GetGateways(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
db := mysql.DB.Model(model.DefaultGateway) db := dbpkg.DB.Model(model.DefaultGateway)
db = filterEqual(ctx, db, "id", "type") db = filterEqual(ctx, db, "id", "type")
db = filterLike(ctx, db, "name") db = filterLike(ctx, db, "name")
db = filterSearch(ctx, db, "name", "host", "account", "port") db = filterSearch(ctx, db, "name", "host", "account", "port")
@@ -149,7 +149,7 @@ func (c *Controller) GetGateways(ctx *gin.Context) {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return
} }
sub := mysql.DB. sub := dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("DISTINCT gateway_id"). Select("DISTINCT gateway_id").
Where("asset_id IN ?", assetIds) Where("asset_id IN ?", assetIds)
@@ -159,5 +159,5 @@ func (c *Controller) GetGateways(ctx *gin.Context) {
db = db.Order("name") db = db.Order("name")
doGet(ctx, !info, db, conf.RESOURCE_GATEWAY, gatewayPostHooks...) doGet(ctx, !info, db, config.RESOURCE_GATEWAY, gatewayPostHooks...)
} }

View File

@@ -6,9 +6,9 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
mysql "github.com/veops/oneterm/db" myi18n "github.com/veops/oneterm/internal/i18n"
myi18n "github.com/veops/oneterm/i18n" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
) )
// GetHistories godoc // GetHistories godoc
@@ -26,7 +26,7 @@ import (
// @Success 200 {object} HttpResponse{data=ListData{list=[]model.History}} // @Success 200 {object} HttpResponse{data=ListData{list=[]model.History}}
// @Router /history [get] // @Router /history [get]
func (c *Controller) GetHistories(ctx *gin.Context) { func (c *Controller) GetHistories(ctx *gin.Context) {
db := mysql.DB.Model(&model.History{}) db := dbpkg.DB.Model(&model.History{})
db = filterSearch(ctx, db, "old", "new") db = filterSearch(ctx, db, "old", "new")
db, err := filterStartEnd(ctx, db) db, err := filterStartEnd(ctx, db)
if err != nil { if err != nil {

View File

@@ -13,13 +13,13 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
redis "github.com/veops/oneterm/cache" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/repository"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/pkg/cache"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/logger"
) )
const ( const (
@@ -40,8 +40,8 @@ var (
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /node [post] // @Router /node [post]
func (c *Controller) CreateNode(ctx *gin.Context) { func (c *Controller) CreateNode(ctx *gin.Context) {
redis.RC.Del(ctx, kFmtAllNodes) cache.RC.Del(ctx, kFmtAllNodes)
doCreate(ctx, true, &model.Node{}, conf.RESOURCE_NODE) doCreate(ctx, true, &model.Node{}, config.RESOURCE_NODE)
} }
// DeleteNode godoc // DeleteNode godoc
@@ -51,8 +51,8 @@ func (c *Controller) CreateNode(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /node/:id [delete] // @Router /node/:id [delete]
func (c *Controller) DeleteNode(ctx *gin.Context) { func (c *Controller) DeleteNode(ctx *gin.Context) {
redis.RC.Del(ctx, kFmtAllNodes) cache.RC.Del(ctx, kFmtAllNodes)
doDelete(ctx, true, &model.Node{}, conf.RESOURCE_NODE, nodeDcs...) doDelete(ctx, true, &model.Node{}, config.RESOURCE_NODE, nodeDcs...)
} }
// UpdateNode godoc // UpdateNode godoc
@@ -63,8 +63,8 @@ func (c *Controller) DeleteNode(ctx *gin.Context) {
// @Success 200 {object} HttpResponse // @Success 200 {object} HttpResponse
// @Router /node/:id [put] // @Router /node/:id [put]
func (c *Controller) UpdateNode(ctx *gin.Context) { func (c *Controller) UpdateNode(ctx *gin.Context) {
redis.RC.Del(ctx, kFmtAllNodes) cache.RC.Del(ctx, kFmtAllNodes)
doUpdate(ctx, true, &model.Node{}, conf.RESOURCE_NODE, nodePreHooks...) doUpdate(ctx, true, &model.Node{}, config.RESOURCE_NODE, nodePreHooks...)
} }
// GetNodes godoc // GetNodes godoc
@@ -85,7 +85,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
db := mysql.DB.Model(model.DefaultNode) db := dbpkg.DB.Model(model.DefaultNode)
db = filterEqual(ctx, db, "id", "parent_id") db = filterEqual(ctx, db, "id", "parent_id")
db = filterLike(ctx, db, "name") db = filterLike(ctx, db, "name")
@@ -123,12 +123,12 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
} }
} }
doGet(ctx, !info, db, conf.RESOURCE_NODE, nodePostHooks...) doGet(ctx, !info, db, config.RESOURCE_NODE, nodePostHooks...)
} }
func nodePreHookCheckCycle(ctx *gin.Context, data *model.Node) { func nodePreHookCheckCycle(ctx *gin.Context, data *model.Node) {
nodes := make([]*model.Node, 0) nodes := make([]*model.Node, 0)
err := mysql.DB.Model(model.DefaultNode).Find(&nodes).Error err := dbpkg.DB.Model(model.DefaultNode).Find(&nodes).Error
g := make(map[int][]int) g := make(map[int][]int)
for _, n := range nodes { for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id) g[n.ParentId] = append(g[n.ParentId], n.Id)
@@ -150,7 +150,7 @@ func nodePreHookCheckCycle(ctx *gin.Context, data *model.Node) {
func nodePostHookCountAsset(ctx *gin.Context, data []*model.Node) { func nodePostHookCountAsset(ctx *gin.Context, data []*model.Node) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
assets := make([]*model.AssetIdPid, 0) assets := make([]*model.AssetIdPid, 0)
db := mysql.DB.Model(model.DefaultAsset) db := dbpkg.DB.Model(model.DefaultAsset)
if !acl.IsAdmin(currentUser) { if !acl.IsAdmin(currentUser) {
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
if info { if info {
@@ -160,7 +160,7 @@ func nodePostHookCountAsset(ctx *gin.Context, data []*model.Node) {
} }
db = db.Where("id IN ?", assetIds) db = db.Where("id IN ?", assetIds)
} else { } else {
assetResId, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_ASSET) assetResId, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
if err != nil { if err != nil {
return return
} }
@@ -174,7 +174,7 @@ func nodePostHookCountAsset(ctx *gin.Context, data []*model.Node) {
logger.L().Error("node posthookfailed asset count", zap.Error(err)) logger.L().Error("node posthookfailed asset count", zap.Error(err))
return return
} }
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
logger.L().Error("node posthookfailed node", zap.Error(err)) logger.L().Error("node posthookfailed node", zap.Error(err))
return return
@@ -203,32 +203,37 @@ func nodePostHookCountAsset(ctx *gin.Context, data []*model.Node) {
func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) { func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
info := cast.ToBool(ctx.Query("info")) info := cast.ToBool(ctx.Query("info"))
ps := make(map[int]bool, 0)
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
nodes, _ := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, _ := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if acl.IsAdmin(currentUser) { if acl.IsAdmin(currentUser) {
ps = lo.SliceToMap(nodes, func(n *model.Node) (int, bool) { return n.ParentId, true }) ps := lo.SliceToMap(nodes, func(n *model.Node) (int, bool) { return n.ParentId, true })
for _, n := range data {
n.HasChild = ps[n.Id]
}
} else { } else {
assets, _ := util.GetAllFromCacheDb(ctx, model.DefaultAsset) assets, _ := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
if info { if info {
assetIds, _ := GetAssetIdsByAuthorization(ctx) assetIds, _ := GetAssetIdsByAuthorization(ctx)
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetIds, a.Id) }) assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetIds, a.Id) })
pids := lo.Map(assets, func(a *model.Asset, _ int) int { return a.ParentId }) pids := lo.Map(assets, func(a *model.Asset, _ int) int { return a.ParentId })
pids, _ = handleSelfParent(ctx, pids...) pids, _ = handleSelfParent(ctx, pids...)
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(pids, n.Id) }) nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(pids, n.Id) })
ps = lo.SliceToMap(nodes, func(a *model.Node) (int, bool) { return a.ParentId, true }) ps := lo.SliceToMap(nodes, func(a *model.Node) (int, bool) { return a.ParentId, true })
for _, n := range data {
n.HasChild = ps[n.Id]
}
} else { } else {
var assetResIds, nodeResIds, pids, nids []int var assetResIds, nodeResIds, pids, nids []int
eg := errgroup.Group{} eg := errgroup.Group{}
eg.Go(func() (err error) { eg.Go(func() (err error) {
assetResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_ASSET) assetResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetResIds, a.ResourceId) }) assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetResIds, a.ResourceId) })
pids = lo.Map(assets, func(n *model.Asset, _ int) int { return n.ParentId }) pids = lo.Map(assets, func(n *model.Asset, _ int) int { return n.ParentId })
pids, _ = handleSelfParent(ctx, pids...) pids, _ = handleSelfParent(ctx, pids...)
return return
}) })
eg.Go(func() (err error) { eg.Go(func() (err error) {
nodeResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), conf.RESOURCE_NODE) nodeResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
ns := lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) }) ns := lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
nids, _ = handleSelfChild(ctx, lo.Map(ns, func(n *model.Node, _ int) int { return n.Id })...) nids, _ = handleSelfChild(ctx, lo.Map(ns, func(n *model.Node, _ int) int { return n.Id })...)
nids, _ = handleSelfParent(ctx, nids...) nids, _ = handleSelfParent(ctx, nids...)
@@ -236,18 +241,18 @@ func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
}) })
eg.Wait() eg.Wait()
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(pids, n.Id) || lo.Contains(nids, n.Id) }) nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(pids, n.Id) || lo.Contains(nids, n.Id) })
ps = lo.SliceToMap(nodes, func(n *model.Node) (int, bool) { return n.ParentId, true }) ps := lo.SliceToMap(nodes, func(n *model.Node) (int, bool) { return n.ParentId, true })
for _, n := range data {
n.HasChild = ps[n.Id]
}
} }
} }
for _, n := range data {
n.HasChild = ps[n.Id]
}
} }
func nodeDelHook(ctx *gin.Context, id int) { func nodeDelHook(ctx *gin.Context, id int) {
noChild := true noChild := true
noChild = noChild && errors.Is(mysql.DB.Model(model.DefaultNode).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound) noChild = noChild && errors.Is(dbpkg.DB.Model(model.DefaultNode).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound)
noChild = noChild && errors.Is(mysql.DB.Model(model.DefaultAsset).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound) noChild = noChild && errors.Is(dbpkg.DB.Model(model.DefaultAsset).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound)
if noChild { if noChild {
return return
} }
@@ -257,7 +262,7 @@ func nodeDelHook(ctx *gin.Context, id int) {
} }
func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) { func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -283,7 +288,7 @@ func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
} }
func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) { func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -312,7 +317,7 @@ func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
} }
func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -328,7 +333,7 @@ func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
} }
func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) { func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -348,7 +353,7 @@ func GetNodeIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
if err != nil { if err != nil {
return return
} }
assets, _ := util.GetAllFromCacheDb(ctx, model.DefaultAsset) assets, _ := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetIds, a.Id) }) assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetIds, a.Id) })
ids = lo.Uniq(lo.Map(assets, func(a *model.Asset, _ int) int { return a.ParentId })) ids = lo.Uniq(lo.Map(assets, func(a *model.Asset, _ int) int { return a.ParentId }))
@@ -356,7 +361,7 @@ func GetNodeIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
} }
func handleSelfChildPerms(ctx context.Context, id2perms map[int][]string) (res map[int][]string, err error) { func handleSelfChildPerms(ctx context.Context, id2perms map[int][]string) (res map[int][]string, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }
@@ -382,7 +387,7 @@ func handleSelfChildPerms(ctx context.Context, id2perms map[int][]string) (res m
} }
func getNodeId2ResId(ctx context.Context) (resid2ids map[int]int, err error) { func getNodeId2ResId(ctx context.Context) (resid2ids map[int]int, err error) {
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil { if err != nil {
return return
} }

View File

@@ -7,10 +7,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/utils"
) )
var ( var (
@@ -23,7 +23,7 @@ var (
} }
}, },
func(ctx *gin.Context, data *model.PublicKey) { func(ctx *gin.Context, data *model.PublicKey) {
data.Pk = util.EncryptAES(data.Pk) data.Pk = utils.EncryptAES(data.Pk)
}, },
func(ctx *gin.Context, data *model.PublicKey) { func(ctx *gin.Context, data *model.PublicKey) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
@@ -34,7 +34,7 @@ var (
publicKeyPostHooks = []postHook[*model.PublicKey]{ publicKeyPostHooks = []postHook[*model.PublicKey]{
func(ctx *gin.Context, data []*model.PublicKey) { func(ctx *gin.Context, data []*model.PublicKey) {
for _, d := range data { for _, d := range data {
d.Pk = util.DecryptAES(d.Pk) d.Pk = utils.DecryptAES(d.Pk)
} }
}, },
} }
@@ -84,7 +84,7 @@ func (c *Controller) UpdatePublicKey(ctx *gin.Context) {
func (c *Controller) GetPublicKeys(ctx *gin.Context) { func (c *Controller) GetPublicKeys(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
db := mysql.DB.Model(&model.PublicKey{}) db := dbpkg.DB.Model(&model.PublicKey{})
db = filterSearch(ctx, db, "name", "mac") db = filterSearch(ctx, db, "name", "mac")
db = filterEqual(ctx, db, "id") db = filterEqual(ctx, db, "id")
db = filterLike(ctx, db, "name") db = filterLike(ctx, db, "name")

View File

@@ -12,10 +12,10 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/logger" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -26,7 +26,7 @@ var (
return return
} }
post := make([]*model.CmdCount, 0) post := make([]*model.CmdCount, 0)
if err := mysql.DB. if err := dbpkg.DB.
Model(&model.SessionCmd{}). Model(&model.SessionCmd{}).
Select("session_id, COUNT(*) AS count"). Select("session_id, COUNT(*) AS count").
Where("session_id IN ?", sessionIds). Where("session_id IN ?", sessionIds).
@@ -66,7 +66,7 @@ func (c *Controller) CreateSessionCmd(ctx *gin.Context) {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument, Data: map[string]any{"err": err}})
return return
} }
if err := mysql.DB. if err := dbpkg.DB.
Create(data). Create(data).
Error; err != nil { Error; err != nil {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
@@ -91,7 +91,7 @@ func (c *Controller) CreateSessionCmd(ctx *gin.Context) {
// @Success 200 {object} HttpResponse{data=ListData{list=[]model.Session}} // @Success 200 {object} HttpResponse{data=ListData{list=[]model.Session}}
// @Router /session [get] // @Router /session [get]
func (c *Controller) GetSessions(ctx *gin.Context) { func (c *Controller) GetSessions(ctx *gin.Context) {
db := mysql.DB.Model(model.DefaultSession) db := dbpkg.DB.Model(model.DefaultSession)
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
if !acl.IsAdmin(currentUser) { if !acl.IsAdmin(currentUser) {
db = db.Where("uid = ?", currentUser.Uid) db = db.Where("uid = ?", currentUser.Uid)
@@ -116,7 +116,7 @@ func (c *Controller) GetSessions(ctx *gin.Context) {
// @Success 200 {object} HttpResponse{data=ListData{list=[]model.SessionCmd}} // @Success 200 {object} HttpResponse{data=ListData{list=[]model.SessionCmd}}
// @Router /session/:session_id/cmd [get] // @Router /session/:session_id/cmd [get]
func (c *Controller) GetSessionCmds(ctx *gin.Context) { func (c *Controller) GetSessionCmds(ctx *gin.Context) {
db := mysql.DB.Model(&model.SessionCmd{}) db := dbpkg.DB.Model(&model.SessionCmd{})
db = db.Where("session_id = ?", ctx.Param("session_id")) db = db.Where("session_id = ?", ctx.Param("session_id"))
db = filterSearch(ctx, db, "cmd", "result") db = filterSearch(ctx, db, "cmd", "result")
@@ -130,7 +130,7 @@ func (c *Controller) GetSessionCmds(ctx *gin.Context) {
// @Router /session/option/asset [get] // @Router /session/option/asset [get]
func (c *Controller) GetSessionOptionAsset(ctx *gin.Context) { func (c *Controller) GetSessionOptionAsset(ctx *gin.Context) {
opts := make([]*model.SessionOptionAsset, 0) opts := make([]*model.SessionOptionAsset, 0)
if err := mysql.DB. if err := dbpkg.DB.
Model(model.DefaultAsset). Model(model.DefaultAsset).
Select("id, name"). Select("id, name").
Find(&opts). Find(&opts).
@@ -149,7 +149,7 @@ func (c *Controller) GetSessionOptionAsset(ctx *gin.Context) {
// @Router /session/option/clientip [get] // @Router /session/option/clientip [get]
func (c *Controller) GetSessionOptionClientIp(ctx *gin.Context) { func (c *Controller) GetSessionOptionClientIp(ctx *gin.Context) {
opts := make([]string, 0) opts := make([]string, 0)
if err := mysql.DB. if err := dbpkg.DB.
Model(model.DefaultSession). Model(model.DefaultSession).
Distinct("client_ip"). Distinct("client_ip").
Find(&opts). Find(&opts).
@@ -200,7 +200,7 @@ func (c *Controller) CreateSessionReplay(ctx *gin.Context) {
func (c *Controller) GetSessionReplay(ctx *gin.Context) { func (c *Controller) GetSessionReplay(ctx *gin.Context) {
sessionId := ctx.Param("session_id") sessionId := ctx.Param("session_id")
session := &model.Session{} session := &model.Session{}
if err := mysql.DB.Model(session).Where("session_id = ?", sessionId).First(session).Error; err != nil { if err := dbpkg.DB.Model(session).Where("session_id = ?", sessionId).First(session).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
} }
filename := sessionId filename := sessionId

View File

@@ -12,9 +12,9 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
) )
// CreateShare godoc // CreateShare godoc
@@ -46,7 +46,7 @@ func (c *Controller) CreateShare(ctx *gin.Context) {
s.Uuid = uuid.New().String() s.Uuid = uuid.New().String()
return s.Uuid return s.Uuid
}) })
if err := mysql.DB.Create(&shares).Error; err != nil { if err := dbpkg.DB.Create(&shares).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return return
} }
@@ -64,7 +64,7 @@ func (c *Controller) DeleteShare(ctx *gin.Context) {
Id: cast.ToInt(ctx.Param("id")), Id: cast.ToInt(ctx.Param("id")),
} }
if err := mysql.DB.Model(share).Where("id=?", share.Id).First(share); err != nil { if err := dbpkg.DB.Model(share).Where("id=?", share.Id).First(share); err != nil {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument, Data: map[string]any{"err": err}}) ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument, Data: map[string]any{"err": err}})
return return
} }
@@ -91,7 +91,7 @@ func (c *Controller) DeleteShare(ctx *gin.Context) {
func (c *Controller) GetShare(ctx *gin.Context) { func (c *Controller) GetShare(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx) currentUser, _ := acl.GetSessionFromCtx(ctx)
db := mysql.DB.Model(&model.Share{}) db := dbpkg.DB.Model(&model.Share{})
db = filterSearch(ctx, db) db = filterSearch(ctx, db)
db, err := filterStartEnd(ctx, db) db, err := filterStartEnd(ctx, db)
if err != nil { if err != nil {
@@ -121,7 +121,7 @@ func (c *Controller) GetShare(ctx *gin.Context) {
// @Router /share/connect/:uuid [get] // @Router /share/connect/:uuid [get]
func (c *Controller) ConnectShare(ctx *gin.Context) { func (c *Controller) ConnectShare(ctx *gin.Context) {
share := &model.Share{} share := &model.Share{}
if err := mysql.DB.Transaction(func(tx *gorm.DB) (err error) { if err := dbpkg.DB.Transaction(func(tx *gorm.DB) (err error) {
if err = tx.Where("uuid=?", ctx.Param("uuid")).First(share).Error; err != nil { if err = tx.Where("uuid=?", ctx.Param("uuid")).First(share).Error; err != nil {
return return
} }

View File

@@ -10,11 +10,11 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
redis "github.com/veops/oneterm/cache" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/cache"
"github.com/veops/oneterm/util" dbpkg "github.com/veops/oneterm/pkg/db"
) )
// StatAssetType godoc // StatAssetType godoc
@@ -25,7 +25,7 @@ import (
func (c *Controller) StatAssetType(ctx *gin.Context) { func (c *Controller) StatAssetType(ctx *gin.Context) {
stat := make([]*model.StatAssetType, 0) stat := make([]*model.StatAssetType, 0)
key := "stat-assettype" key := "stat-assettype"
if redis.Get(ctx, key, stat) == nil { if cache.Get(ctx, key, stat) == nil {
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
return return
} }
@@ -35,7 +35,7 @@ func (c *Controller) StatAssetType(ctx *gin.Context) {
ctx.AbortWithError(http.StatusInternalServerError, err) ctx.AbortWithError(http.StatusInternalServerError, err)
return return
} }
if err = mysql.DB. if err = dbpkg.DB.
Model(stat). Model(stat).
Where("parent_id = 0"). Where("parent_id = 0").
Find(&stat). Find(&stat).
@@ -47,7 +47,7 @@ func (c *Controller) StatAssetType(ctx *gin.Context) {
s.Count = m[s.Id] s.Count = m[s.Id]
} }
redis.SetEx(ctx, key, stat, time.Minute) cache.SetEx(ctx, key, stat, time.Minute)
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
} }
@@ -60,14 +60,14 @@ func (c *Controller) StatAssetType(ctx *gin.Context) {
func (c *Controller) StatCount(ctx *gin.Context) { func (c *Controller) StatCount(ctx *gin.Context) {
stat := &model.StatCount{} stat := &model.StatCount{}
key := "stat-count" key := "stat-count"
if redis.Get(ctx, key, stat) == nil { if cache.Get(ctx, key, stat) == nil {
ctx.JSON(http.StatusOK, NewHttpResponseWithData(stat)) ctx.JSON(http.StatusOK, NewHttpResponseWithData(stat))
return return
} }
eg := &errgroup.Group{} eg := &errgroup.Group{}
eg.Go(func() error { eg.Go(func() error {
return mysql.DB. return dbpkg.DB.
Model(model.DefaultSession). Model(model.DefaultSession).
Select("COUNT(DISTINCT asset_id, account_id) as connect, COUNT(DISTINCT uid) as user, COUNT(DISTINCT gateway_id) as gateway, COUNT(*) as session"). Select("COUNT(DISTINCT asset_id, account_id) as connect, COUNT(DISTINCT uid) as user, COUNT(DISTINCT gateway_id) as gateway, COUNT(*) as session").
Where("status = 1"). Where("status = 1").
@@ -75,13 +75,13 @@ func (c *Controller) StatCount(ctx *gin.Context) {
Error Error
}) })
eg.Go(func() error { eg.Go(func() error {
return mysql.DB.Model(model.DefaultAsset).Count(&stat.TotalAsset).Error return dbpkg.DB.Model(model.DefaultAsset).Count(&stat.TotalAsset).Error
}) })
eg.Go(func() error { eg.Go(func() error {
return mysql.DB.Model(model.DefaultAsset).Where("connectable = 1").Count(&stat.Asset).Error return dbpkg.DB.Model(model.DefaultAsset).Where("connectable = 1").Count(&stat.Asset).Error
}) })
eg.Go(func() error { eg.Go(func() error {
return mysql.DB.Model(model.DefaultGateway).Count(&stat.TotalGateway).Error return dbpkg.DB.Model(model.DefaultGateway).Count(&stat.TotalGateway).Error
}) })
if err := eg.Wait(); err != nil { if err := eg.Wait(); err != nil {
@@ -90,7 +90,7 @@ func (c *Controller) StatCount(ctx *gin.Context) {
} }
stat.Gateway = lo.Ternary(stat.Gateway <= stat.TotalGateway, stat.Gateway, stat.TotalGateway) stat.Gateway = lo.Ternary(stat.Gateway <= stat.TotalGateway, stat.Gateway, stat.TotalGateway)
redis.SetEx(ctx, key, stat, time.Minute) cache.SetEx(ctx, key, stat, time.Minute)
ctx.JSON(http.StatusOK, NewHttpResponseWithData(stat)) ctx.JSON(http.StatusOK, NewHttpResponseWithData(stat))
} }
@@ -117,12 +117,12 @@ func (c *Controller) StatAccount(ctx *gin.Context) {
stat := make([]*model.StatAccount, 0) stat := make([]*model.StatAccount, 0)
key := "stat-account-" + ctx.Query("type") key := "stat-account-" + ctx.Query("type")
if redis.Get(ctx, key, stat) == nil { if cache.Get(ctx, key, stat) == nil {
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
return return
} }
err := mysql.DB. err := dbpkg.DB.
Model(&model.Account{}). Model(&model.Account{}).
Select("account.name, COUNT(*) AS count"). Select("account.name, COUNT(*) AS count").
Joins("LEFT JOIN session ON account.id = session.account_id"). Joins("LEFT JOIN session ON account.id = session.account_id").
@@ -137,7 +137,7 @@ func (c *Controller) StatAccount(ctx *gin.Context) {
return return
} }
redis.SetEx(ctx, key, stat, time.Minute) cache.SetEx(ctx, key, stat, time.Minute)
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
} }
@@ -170,11 +170,11 @@ func (c *Controller) StatAsset(ctx *gin.Context) {
stat := make([]*model.StatAsset, 0) stat := make([]*model.StatAsset, 0)
key := "stat-asset-" + ctx.Query("type") key := "stat-asset-" + ctx.Query("type")
if redis.Get(ctx, key, stat) == nil { if cache.Get(ctx, key, stat) == nil {
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
return return
} }
err := mysql.DB. err := dbpkg.DB.
Model(model.DefaultSession). Model(model.DefaultSession).
Select("COUNT(DISTINCT asset_id, uid) AS connect, COUNT(*) AS session, COUNT(DISTINCT asset_id) AS asset, COUNT(DISTINCT uid) AS user, DATE_FORMAT(created_at, ?) AS time", dateFmt). Select("COUNT(DISTINCT asset_id, uid) AS connect, COUNT(*) AS session, COUNT(DISTINCT asset_id) AS asset, COUNT(DISTINCT uid) AS user, DATE_FORMAT(created_at, ?) AS time", dateFmt).
Where("session.created_at >= ? AND session.created_at <= ?", start, end). Where("session.created_at >= ? AND session.created_at <= ?", start, end).
@@ -196,7 +196,7 @@ func (c *Controller) StatAsset(ctx *gin.Context) {
sort.Slice(stat, func(i, j int) bool { return stat[i].Time < stat[j].Time }) sort.Slice(stat, func(i, j int) bool { return stat[i].Time < stat[j].Time })
redis.SetEx(ctx, key, stat, time.Minute) cache.SetEx(ctx, key, stat, time.Minute)
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
} }
@@ -212,7 +212,7 @@ func (c *Controller) StatCountOfUser(ctx *gin.Context) {
eg := &errgroup.Group{} eg := &errgroup.Group{}
eg.Go(func() error { eg.Go(func() error {
return mysql.DB. return dbpkg.DB.
Model(model.DefaultSession). Model(model.DefaultSession).
Select("COUNT(DISTINCT asset_id, account_id) as connect, COUNT(DISTINCT asset_id) as asset, COUNT(*) as session"). Select("COUNT(DISTINCT asset_id, account_id) as connect, COUNT(DISTINCT asset_id) as asset, COUNT(*) as session").
Where("status = 1"). Where("status = 1").
@@ -222,7 +222,7 @@ func (c *Controller) StatCountOfUser(ctx *gin.Context) {
}) })
eg.Go(func() error { eg.Go(func() error {
isAdmin := acl.IsAdmin(currentUser) isAdmin := acl.IsAdmin(currentUser)
assets, err := util.GetAllFromCacheDb(ctx, model.DefaultAsset) assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
if !isAdmin { if !isAdmin {
assetIds, err := GetAssetIdsByAuthorization(ctx) assetIds, err := GetAssetIdsByAuthorization(ctx)
if err != nil { if err != nil {
@@ -250,12 +250,12 @@ func (c *Controller) StatCountOfUser(ctx *gin.Context) {
func (c *Controller) StatRankOfUser(ctx *gin.Context) { func (c *Controller) StatRankOfUser(ctx *gin.Context) {
stat := make([]*model.StatRankOfUser, 0) stat := make([]*model.StatRankOfUser, 0)
key := "stat-rank-user" key := "stat-rank-user"
if redis.Get(ctx, key, stat) == nil { if cache.Get(ctx, key, stat) == nil {
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
return return
} }
if err := mysql.DB. if err := dbpkg.DB.
Model(model.DefaultSession). Model(model.DefaultSession).
Select("uid, COUNT(*) AS count, MAX(created_at) AS last_time"). Select("uid, COUNT(*) AS count, MAX(created_at) AS last_time").
Group("uid"). Group("uid").
@@ -267,7 +267,7 @@ func (c *Controller) StatRankOfUser(ctx *gin.Context) {
return return
} }
redis.SetEx(ctx, key, stat, time.Minute) cache.SetEx(ctx, key, stat, time.Minute)
ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat))) ctx.JSON(http.StatusOK, NewHttpResponseWithData(toListData(stat)))
} }
@@ -281,11 +281,11 @@ func toListData[T any](data []T) *ListData {
func nodeCountAsset() (m map[int]int64, err error) { func nodeCountAsset() (m map[int]int64, err error) {
assets := make([]*model.AssetIdPid, 0) assets := make([]*model.AssetIdPid, 0)
if err = mysql.DB.Model(model.DefaultAsset).Find(&assets).Error; err != nil { if err = dbpkg.DB.Model(model.DefaultAsset).Find(&assets).Error; err != nil {
return return
} }
nodes := make([]*model.Node, 0) nodes := make([]*model.Node, 0)
if err = mysql.DB.Model(model.DefaultNode).Find(&nodes).Error; err != nil { if err = dbpkg.DB.Model(model.DefaultNode).Find(&nodes).Error; err != nil {
return return
} }
m = make(map[int]int64) m = make(map[int]int64)

View File

@@ -0,0 +1,69 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/api/controller"
"github.com/veops/oneterm/pkg/logger"
)
var (
errUnauthorized = &controller.ApiError{Code: controller.ErrUnauthorized}
)
func AuthMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
var (
sess *acl.Session
err error
cookie string
)
m := make(map[string]any)
ctx.ShouldBindBodyWithJSON(&m)
if ctx.Request.Method == "GET" {
if _, ok := ctx.GetQuery("_key"); ok {
m["_key"] = ctx.Query("_key")
m["_secret"] = ctx.Query("_secret")
}
}
if _, ok := m["_key"]; ok {
sess, err = acl.AuthWithKey(ctx.Request.URL.Path, m)
if err != nil {
logger.L().Error("cannot authwithkey", zap.Error(err))
ctx.AbortWithError(http.StatusUnauthorized, errUnauthorized)
return
}
ctx.Set("isAuthWithKey", true)
} else {
cookie, err = ctx.Cookie("session")
if err != nil || cookie == "" {
logger.L().Error("cannot get cookie.session", zap.Error(err))
ctx.AbortWithError(http.StatusUnauthorized, errUnauthorized)
return
}
sess, err = acl.ParseCookie(cookie)
}
if err != nil {
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
ctx.Set("session", sess)
ctx.Next()
}
}
func authAdmin() gin.HandlerFunc {
return func(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
if !acl.IsAdmin(currentUser) {
ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
}
}

View File

@@ -0,0 +1,63 @@
package middleware
import (
"bytes"
"encoding/json"
"strings"
"github.com/gin-gonic/gin"
"github.com/nicksnyder/go-i18n/v2/i18n"
"github.com/veops/oneterm/internal/api/controller"
myi18n "github.com/veops/oneterm/internal/i18n"
)
type bodyWriter struct {
gin.ResponseWriter
body *bytes.Buffer
}
func (w bodyWriter) Write(b []byte) (int, error) {
return w.body.Write(b)
}
func Error2RespMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
if strings.Contains(ctx.Request.URL.String(), "session/replay") {
ctx.Next()
return
}
wb := &bodyWriter{
body: &bytes.Buffer{},
ResponseWriter: ctx.Writer,
}
ctx.Writer = wb
ctx.Next()
obj := make(map[string]any)
json.Unmarshal(wb.body.Bytes(), &obj)
if len(ctx.Errors) > 0 {
if v, ok := obj["code"]; !ok || v == 0 {
obj["code"] = ctx.Writer.Status()
}
if v, ok := obj["message"]; !ok || v == "" {
e := ctx.Errors.Last().Err
obj["message"] = e.Error()
ae, ok := e.(*controller.ApiError)
if ok {
lang := ctx.PostForm("lang")
accept := ctx.GetHeader("Accept-Language")
localizer := i18n.NewLocalizer(myi18n.Bundle, lang, accept)
obj["message"] = ae.Message(localizer)
}
}
}
bs, _ := json.Marshal(obj)
wb.ResponseWriter.Write(bs)
}
}

View File

@@ -0,0 +1,27 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"github.com/veops/oneterm/pkg/logger"
)
func LoggerMiddleware() gin.HandlerFunc {
return func(ctx *gin.Context) {
start := time.Now()
ctx.Next()
cost := time.Since(start)
logger.L().Info(ctx.Request.URL.String(),
zap.String("method", ctx.Request.Method),
zap.Int("status", ctx.Writer.Status()),
zap.String("ip", ctx.ClientIP()),
zap.Duration("cost", cost),
)
}
}

View File

@@ -1,38 +1,26 @@
package api package router
import ( import (
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
swaggerFiles "github.com/swaggo/files" swaggerFiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger" ginSwagger "github.com/swaggo/gin-swagger"
"go.uber.org/zap"
"github.com/veops/oneterm/api/controller" "github.com/veops/oneterm/internal/api/controller"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/api/docs"
"github.com/veops/oneterm/docs" "github.com/veops/oneterm/internal/api/middleware"
"github.com/veops/oneterm/logger"
) )
var ( func SetupRouter(r *gin.Engine) {
ctx, cancel = context.WithCancel(context.Background())
srv = &http.Server{}
)
func RunApi() error {
c := controller.Controller{}
r := gin.New()
r.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}) r.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"})
r.MaxMultipartMemory = 128 << 20 r.MaxMultipartMemory = 128 << 20
r.Use(gin.Recovery(), ginLogger()) r.Use(gin.Recovery(), middleware.LoggerMiddleware())
docs.SwaggerInfo.Title = "ONETERM API" docs.SwaggerInfo.Title = "ONETERM API"
docs.SwaggerInfo.BasePath = "/api/oneterm/v1" docs.SwaggerInfo.BasePath = "/api/oneterm/v1"
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
v1 := r.Group("/api/oneterm/v1", Error2Resp(), auth()) c := controller.Controller{}
v1 := r.Group("/api/oneterm/v1", middleware.Error2RespMiddleware(), middleware.AuthMiddleware())
{ {
account := v1.Group("account") account := v1.Group("account")
{ {
@@ -135,7 +123,8 @@ func RunApi() error {
share.DELETE("/:id", c.DeleteShare) share.DELETE("/:id", c.DeleteShare)
share.GET("", c.GetShare) share.GET("", c.GetShare)
} }
r.GET("/api/oneterm/v1/share/connect/:uuid", Error2Resp(), c.ConnectShare)
r.GET("/api/oneterm/v1/share/connect/:uuid", middleware.Error2RespMiddleware(), c.ConnectShare)
authorization := v1.Group("/authorization") authorization := v1.Group("/authorization")
{ {
@@ -145,16 +134,5 @@ func RunApi() error {
} }
} }
srv.Addr = fmt.Sprintf("%s:%d", conf.Cfg.Http.Host, conf.Cfg.Http.Port) return
srv.Handler = r
err := srv.ListenAndServe()
if err != nil {
logger.L().Fatal("start http failed", zap.Error(err))
}
return err
}
func StopApi() {
defer cancel()
srv.Shutdown(ctx)
} }

View File

@@ -11,8 +11,8 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (

View File

@@ -9,10 +9,11 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"github.com/spf13/cast" "github.com/spf13/cast"
"github.com/veops/oneterm/conf"
ggateway "github.com/veops/oneterm/gateway" ggateway "github.com/veops/oneterm/internal/gateway"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/pkg/logger"
) )
const ( const (
@@ -44,7 +45,15 @@ type Tunnel struct {
} }
func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, asset *model.Asset, account *model.Account, gateway *model.Gateway) (t *Tunnel, err error) { func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, asset *model.Asset, account *model.Account, gateway *model.Gateway) (t *Tunnel, err error) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", conf.Cfg.Guacd.Host, conf.Cfg.Guacd.Port), time.Second*3) var hostPort string
if strings.Contains(config.Cfg.Guacd.Host, ":") {
// IPv6 address
hostPort = fmt.Sprintf("[%s]:%d", config.Cfg.Guacd.Host, config.Cfg.Guacd.Port)
} else {
// IPv4 address or hostname
hostPort = fmt.Sprintf("%s:%d", config.Cfg.Guacd.Host, config.Cfg.Guacd.Port)
}
conn, err := net.DialTimeout("tcp", hostPort, time.Second*3)
if err != nil { if err != nil {
return return
} }

View File

@@ -8,7 +8,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -19,7 +19,7 @@ var (
func init() { func init() {
Bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal) Bundle.RegisterUnmarshalFunc("toml", toml.Unmarshal)
for _, lang := range langs { for _, lang := range langs {
_, err := Bundle.LoadMessageFile(fmt.Sprintf("./translate/active.%s.toml", lang)) _, err := Bundle.LoadMessageFile(fmt.Sprintf("./locales/active.%s.toml", lang))
if err != nil { if err != nil {
logger.L().Error("load i18n message failed", zap.Error(err)) logger.L().Error("load i18n message failed", zap.Error(err))
} }

View File

@@ -1,13 +1,13 @@
package util package repository
import ( import (
"context" "context"
"fmt" "fmt"
"time" "time"
redis "github.com/veops/oneterm/cache" "github.com/veops/oneterm/internal/model"
mysql "github.com/veops/oneterm/db" redis "github.com/veops/oneterm/pkg/cache"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/db"
) )
func GetAllFromCacheDb[T model.Model](ctx context.Context, m T) (res []T, err error) { func GetAllFromCacheDb[T model.Model](ctx context.Context, m T) (res []T, err error) {
@@ -15,7 +15,7 @@ func GetAllFromCacheDb[T model.Model](ctx context.Context, m T) (res []T, err er
if err = redis.Get(ctx, k, &res); err == nil { if err = redis.Get(ctx, k, &res); err == nil {
return return
} }
if err = mysql.DB.Model(m).Find(&res).Error; err != nil { if err = db.DB.Model(m).Find(&res).Error; err != nil {
return return
} }
redis.SetEx(ctx, k, res, time.Hour) redis.SetEx(ctx, k, res, time.Hour)

View File

@@ -0,0 +1,25 @@
package schedule
import (
"time"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/cache"
dbpkg "github.com/veops/oneterm/pkg/db"
)
func UpdateConfig() {
cfg := &model.Config{}
defer func() {
cache.SetEx(ctx, "config", cfg, time.Hour)
model.GlobalConfig.Store(cfg)
}()
err := cache.Get(ctx, "config", cfg)
if err == nil {
return
}
err = dbpkg.DB.Model(cfg).First(cfg).Error
if err != nil {
return
}
}

View File

@@ -10,11 +10,12 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
mysql "github.com/veops/oneterm/db" ggateway "github.com/veops/oneterm/internal/gateway"
ggateway "github.com/veops/oneterm/gateway" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/model" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/logger"
"github.com/veops/oneterm/pkg/utils"
) )
func UpdateConnectables(ids ...int) (err error) { func UpdateConnectables(ids ...int) (err error) {
@@ -24,7 +25,7 @@ func UpdateConnectables(ids ...int) (err error) {
} }
}() }()
assets := make([]*model.Asset, 0) assets := make([]*model.Asset, 0)
db := mysql.DB. db := dbpkg.DB.
Model(assets) Model(assets)
if len(ids) > 0 { if len(ids) > 0 {
db = db.Where("id IN ?", ids) db = db.Where("id IN ?", ids)
@@ -39,7 +40,7 @@ func UpdateConnectables(ids ...int) (err error) {
gids := lo.Without(lo.Uniq(lo.Map(assets, func(a *model.Asset, _ int) int { return a.GatewayId })), 0) gids := lo.Without(lo.Uniq(lo.Map(assets, func(a *model.Asset, _ int) int { return a.GatewayId })), 0)
gateways := make([]*model.Gateway, 0) gateways := make([]*model.Gateway, 0)
if len(gids) > 0 { if len(gids) > 0 {
if err = mysql.DB. if err = dbpkg.DB.
Model(gateways). Model(gateways).
Where("id IN ?", gids). Where("id IN ?", gids).
Find(&gateways).Error; err != nil { Find(&gateways).Error; err != nil {
@@ -48,9 +49,9 @@ func UpdateConnectables(ids ...int) (err error) {
} }
} }
for _, g := range gateways { for _, g := range gateways {
g.Password = util.DecryptAES(g.Password) g.Password = utils.DecryptAES(g.Password)
g.Pk = util.DecryptAES(g.Pk) g.Pk = utils.DecryptAES(g.Pk)
g.Phrase = util.DecryptAES(g.Phrase) g.Phrase = utils.DecryptAES(g.Phrase)
} }
gatewayMap := lo.SliceToMap(gateways, func(g *model.Gateway) (int, *model.Gateway) { return g.Id, g }) gatewayMap := lo.SliceToMap(gateways, func(g *model.Gateway) (int, *model.Gateway) { return g.Id, g })
@@ -65,12 +66,12 @@ func UpdateConnectables(ids ...int) (err error) {
} }
defer ggateway.GetGatewayManager().Close(sids...) defer ggateway.GetGatewayManager().Close(sids...)
if len(oks) > 0 { if len(oks) > 0 {
if err := mysql.DB.Model(assets).Where("id IN ?", oks).Update("connectable", true).Error; err != nil { if err := dbpkg.DB.Model(assets).Where("id IN ?", oks).Update("connectable", true).Error; err != nil {
logger.L().Debug("update connectable to ok failed", zap.Error(err)) logger.L().Debug("update connectable to ok failed", zap.Error(err))
} }
} }
if len(oks) < len(all) { if len(oks) < len(all) {
if err := mysql.DB.Model(assets).Where("id IN ?", lo.Without(all, oks...)).Update("connectable", false).Error; err != nil { if err := dbpkg.DB.Model(assets).Where("id IN ?", lo.Without(all, oks...)).Update("connectable", false).Error; err != nil {
logger.L().Debug("update connectable to fail failed", zap.Error(err)) logger.L().Debug("update connectable to fail failed", zap.Error(err))
} }
} }
@@ -80,15 +81,20 @@ func UpdateConnectables(ids ...int) (err error) {
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) { func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
sid = uuid.New().String() sid = uuid.New().String()
ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",") ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",")
ip, port, err := util.Proxy(true, sid, ps, asset, gateway) ip, port, err := service.Proxy(true, sid, ps, asset, gateway)
if err != nil { if err != nil {
logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err)) logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err))
return return
} }
addr := fmt.Sprintf("%s:%d", ip, port) var hostPort string
conn, err := net.DialTimeout("tcp", addr, time.Second) if strings.Contains(ip, ":") {
hostPort = fmt.Sprintf("[%s]:%d", ip, port)
} else {
hostPort = fmt.Sprintf("%s:%d", ip, port)
}
conn, err := net.DialTimeout("tcp", hostPort, time.Second)
if err != nil { if err != nil {
logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err)) logger.L().Debug("dail failed", zap.String("addr", hostPort), zap.Error(err))
return return
} }
defer conn.Close() defer conn.Close()

View File

@@ -1,4 +1,4 @@
package file package service
import ( import (
"fmt" "fmt"
@@ -8,8 +8,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/veops/oneterm/util"
) )
var ( var (
@@ -70,17 +68,17 @@ func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client,
return return
} }
asset, account, gateway, err := util.GetAAG(assetId, accountId) asset, account, gateway, err := GetAAG(assetId, accountId)
if err != nil { if err != nil {
return return
} }
ip, port, err := util.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway) ip, port, err := Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
if err != nil { if err != nil {
return return
} }
auth, err := util.GetAuth(account) auth, err := GetAuth(account)
if err != nil { if err != nil {
return return
} }

View File

@@ -1,35 +1,36 @@
package util package service
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/spf13/cast"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/spf13/cast" ggateway "github.com/veops/oneterm/internal/gateway"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/model"
ggateway "github.com/veops/oneterm/gateway" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/utils"
) )
func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) { func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) {
asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{} asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{}
if err = mysql.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil { if err = dbpkg.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil {
return return
} }
if err = mysql.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil { if err = dbpkg.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil {
return return
} }
account.Password = DecryptAES(account.Password) account.Password = utils.DecryptAES(account.Password)
account.Pk = DecryptAES(account.Pk) account.Pk = utils.DecryptAES(account.Pk)
account.Phrase = DecryptAES(account.Phrase) account.Phrase = utils.DecryptAES(account.Phrase)
if asset.GatewayId != 0 { if asset.GatewayId != 0 {
if err = mysql.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil { if err = dbpkg.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
return return
} }
gateway.Password = DecryptAES(gateway.Password) gateway.Password = utils.DecryptAES(gateway.Password)
gateway.Pk = DecryptAES(gateway.Pk) gateway.Pk = utils.DecryptAES(gateway.Pk)
gateway.Phrase = DecryptAES(gateway.Phrase) gateway.Phrase = utils.DecryptAES(gateway.Phrase)
} }
return return

View File

@@ -7,10 +7,11 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"github.com/veops/go-ansiterm" "github.com/veops/go-ansiterm"
mysql "github.com/veops/oneterm/db"
"github.com/veops/oneterm/logger"
"github.com/veops/oneterm/model"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/internal/model"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -118,7 +119,7 @@ func (p *Parser) WriteDb() {
Cmd: p.lastCmd, Cmd: p.lastCmd,
Result: p.lastRes, Result: p.lastRes,
} }
err := mysql.DB.Model(m).Create(m).Error err := dbpkg.DB.Model(m).Create(m).Error
if err != nil { if err != nil {
logger.L().Error("write session cmd failed", zap.Error(err), zap.Any("cmd", *m)) logger.L().Error("write session cmd failed", zap.Error(err), zap.Any("cmd", *m))
} }

View File

@@ -9,7 +9,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/pkg/logger"
) )
type Asciinema struct { type Asciinema struct {

View File

@@ -15,10 +15,10 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"github.com/veops/oneterm/api/guacd" "github.com/veops/oneterm/internal/guacd"
mysql "github.com/veops/oneterm/db" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/logger" dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -26,8 +26,9 @@ var (
) )
func init() { func init() {
// After system restart, set all online sessions to offline
sessions := make([]*Session, 0) sessions := make([]*Session, 0)
if err := mysql.DB. if err := dbpkg.DB.
Model(sessions). Model(sessions).
Where("status = ?", model.SESSIONSTATUS_ONLINE). Where("status = ?", model.SESSIONSTATUS_ONLINE).
Find(&sessions). Find(&sessions).
@@ -158,7 +159,7 @@ func NewSession(ctx context.Context) *Session {
} }
func UpsertSession(data *Session) (err error) { func UpsertSession(data *Session) (err error) {
return mysql.DB. return dbpkg.DB.
Clauses(clause.OnConflict{ Clauses(clause.OnConflict{
DoUpdates: clause.AssignmentColumns([]string{"status", "closed_at"}), DoUpdates: clause.AssignmentColumns([]string{"status", "closed_at"}),
}). }).

View File

@@ -15,10 +15,10 @@ import (
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/pkg/logger"
) )
func handler(sess ssh.Session) { func handler(sess ssh.Session) {
@@ -64,7 +64,7 @@ func handler(sess ssh.Session) {
} }
func signer() ssh.Signer { func signer() ssh.Signer {
s, err := gossh.ParsePrivateKey([]byte(conf.Cfg.Ssh.PrivateKey)) s, err := gossh.ParsePrivateKey([]byte(config.Cfg.Ssh.PrivateKey))
if err != nil { if err != nil {
logger.L().Fatal("failed parse signer", zap.Error(err)) logger.L().Fatal("failed parse signer", zap.Error(err))
} }

View File

@@ -7,9 +7,9 @@ import (
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/utils"
) )
var ( var (
@@ -19,15 +19,15 @@ var (
func init() { func init() {
server = &ssh.Server{ server = &ssh.Server{
Addr: fmt.Sprintf("%s:%d", conf.Cfg.Ssh.Host, conf.Cfg.Ssh.Port), Addr: fmt.Sprintf("%s:%d", config.Cfg.Ssh.Host, config.Cfg.Ssh.Port),
Handler: handler, Handler: handler,
PasswordHandler: func(ctx ssh.Context, password string) bool { PasswordHandler: func(ctx ssh.Context, password string) bool {
sess, err := acl.LoginByPassword(ctx, ctx.User(), password, util.IpFromNetAddr(ctx.RemoteAddr())) sess, err := acl.LoginByPassword(ctx, ctx.User(), password, utils.IpFromNetAddr(ctx.RemoteAddr()))
ctx.SetValue("session", sess) ctx.SetValue("session", sess)
return err == nil return err == nil
}, },
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
sess, err := acl.LoginByPublicKey(ctx, ctx.User(), string(gossh.MarshalAuthorizedKey(key)), util.IpFromNetAddr(ctx.RemoteAddr())) sess, err := acl.LoginByPublicKey(ctx, ctx.User(), string(gossh.MarshalAuthorizedKey(key)), utils.IpFromNetAddr(ctx.RemoteAddr()))
ctx.SetValue("session", sess) ctx.SetValue("session", sess)
return err == nil return err == nil
}, },

View File

@@ -20,14 +20,14 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/veops/oneterm/acl" "github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/api/controller" "github.com/veops/oneterm/internal/api/controller"
redis "github.com/veops/oneterm/cache" "github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/model" "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/session" "github.com/veops/oneterm/internal/sshsrv/textinput"
"github.com/veops/oneterm/sshsrv/textinput" "github.com/veops/oneterm/pkg/cache"
"github.com/veops/oneterm/util" "github.com/veops/oneterm/pkg/logger"
) )
const ( const (
@@ -230,11 +230,11 @@ func (m *view) possible() string {
func (m *view) refresh() { func (m *view) refresh() {
eg := &errgroup.Group{} eg := &errgroup.Group{}
eg.Go(func() (err error) { eg.Go(func() (err error) {
assets, err := util.GetAllFromCacheDb(m.gctx, model.DefaultAsset) assets, err := repository.GetAllFromCacheDb(m.gctx, model.DefaultAsset)
if err != nil { if err != nil {
return return
} }
accounts, err := util.GetAllFromCacheDb(m.gctx, model.DefaultAccount) accounts, err := repository.GetAllFromCacheDb(m.gctx, model.DefaultAccount)
if err != nil { if err != nil {
return return
} }
@@ -286,7 +286,7 @@ func (m *view) refresh() {
if len(m.cmds) != 0 { if len(m.cmds) != 0 {
return err return err
} }
m.cmds, err = redis.RC.LRange(m.Ctx, fmt.Sprintf(hisCmdsFmt, m.currentUser.GetUid()), -100, -1).Result() m.cmds, err = cache.RC.LRange(m.Ctx, fmt.Sprintf(hisCmdsFmt, m.currentUser.GetUid()), -100, -1).Result()
m.cmdsIdx = len(m.cmds) m.cmdsIdx = len(m.cmds)
return err return err
}) })
@@ -305,9 +305,9 @@ func (m *view) magicn() tea.Msg {
func (m *view) RecordHisCmd() { func (m *view) RecordHisCmd() {
k := fmt.Sprintf(hisCmdsFmt, m.currentUser.GetUid()) k := fmt.Sprintf(hisCmdsFmt, m.currentUser.GetUid())
redis.RC.RPush(m.Ctx, k, m.cmds) cache.RC.RPush(m.Ctx, k, m.cmds)
redis.RC.LTrim(m.Ctx, k, -100, -1) cache.RC.LTrim(m.Ctx, k, -100, -1)
redis.RC.Expire(m.Ctx, k, time.Hour*24*30) cache.RC.Expire(m.Ctx, k, time.Hour*24*30)
} }
type connector struct { type connector struct {

View File

@@ -1,4 +1,4 @@
package redis package cache
import ( import (
"context" "context"
@@ -9,8 +9,8 @@ import (
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -20,10 +20,10 @@ var (
func init() { func init() {
ctx := context.Background() ctx := context.Background()
addr := fmt.Sprintf("%s:%d", conf.Cfg.Redis.Host, conf.Cfg.Redis.Port) addr := fmt.Sprintf("%s:%d", config.Cfg.Redis.Host, config.Cfg.Redis.Port)
RC = redis.NewClient(&redis.Options{ RC = redis.NewClient(&redis.Options{
Addr: addr, Addr: addr,
Password: conf.Cfg.Redis.Password, Password: config.Cfg.Redis.Password,
}) })
if _, err := RC.Ping(ctx).Result(); err != nil { if _, err := RC.Ping(ctx).Result(); err != nil {

View File

@@ -1,4 +1,4 @@
package conf package config
import ( import (
"fmt" "fmt"
@@ -75,6 +75,21 @@ type MysqlConfig struct {
Password string `yaml:"password"` Password string `yaml:"password"`
} }
type DatabaseConfig struct {
Type string `yaml:"type"` // mysql, postgres, tidb, tdsql, dm, default: mysql
Host string `yaml:"host"` // default: oneterm-mysql
Port string `yaml:"port"` // default: 3306
User string `yaml:"user"` // default: root
Password string `yaml:"password"` // default: root
Database string `yaml:"database"` // default: oneterm
Charset string `yaml:"charset"` // default: utf8mb4
MaxIdleConns int `yaml:"max_idle_conns"` // default: 10
MaxOpenConns int `yaml:"max_open_conns"` // default: 100
ConnMaxLife int `yaml:"conn_max_lifetime"` // seconds, default: 3600
ConnMaxIdle int `yaml:"conn_max_idle_time"` // seconds, default: 600
SSLMode string `yaml:"ssl_mode"` // disable, prefer, require, verify-ca, verify-full, default: disable
}
type KV struct { type KV struct {
Key string Key string
Value string Value string
@@ -125,14 +140,15 @@ type GuacdConfig struct {
} }
type ConfigYaml struct { type ConfigYaml struct {
Mode string `yaml:"mode"` Mode string `yaml:"mode"`
I18nDir string `yaml:"i18nDir"` I18nDir string `yaml:"i18nDir"`
Log LogConfig `yaml:"log"` Log LogConfig `yaml:"log"`
Redis RedisConfig `yaml:"redis"` Redis RedisConfig `yaml:"redis"`
Mysql MysqlConfig `yaml:"mysql"` Mysql MysqlConfig `yaml:"mysql"`
Guacd GuacdConfig `yaml:"guacd"` Database DatabaseConfig `yaml:"database"`
Http HttpConfig `yaml:"http"` Guacd GuacdConfig `yaml:"guacd"`
Ssh SshConfig `yaml:"ssh"` Http HttpConfig `yaml:"http"`
Auth Auth `yaml:"auth"` Ssh SshConfig `yaml:"ssh"`
SecretKey string `yaml:"secretKey"` Auth Auth `yaml:"auth"`
SecretKey string `yaml:"secretKey"`
} }

214
backend/pkg/db/database.go Normal file
View File

@@ -0,0 +1,214 @@
package db
import (
"context"
"fmt"
"strings"
"sync"
"time"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
gormLogger "gorm.io/gorm/logger"
"github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/pkg/logger"
)
var (
DB *gorm.DB
dbOnce sync.Once
)
type DBType string
const (
MySQL DBType = "mysql"
Postgres DBType = "postgres"
TiDB DBType = "tidb"
TDSQL DBType = "tdsql"
)
type Config struct {
Type DBType
Host string
Port string
User string
Password string
Database string
Charset string
MaxIdleConns int
MaxOpenConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
SSLMode string
}
func ConfigFromGlobal() Config {
dbType := DBType(config.Cfg.Database.Type)
if dbType == "" {
dbType = MySQL
}
return Config{
Type: dbType,
Host: config.Cfg.Database.Host,
Port: config.Cfg.Database.Port,
User: config.Cfg.Database.User,
Password: config.Cfg.Database.Password,
Database: config.Cfg.Database.Database,
Charset: config.Cfg.Database.Charset,
MaxIdleConns: config.Cfg.Database.MaxIdleConns,
MaxOpenConns: config.Cfg.Database.MaxOpenConns,
ConnMaxLifetime: time.Duration(config.Cfg.Database.ConnMaxLife) * time.Second,
ConnMaxIdleTime: time.Duration(config.Cfg.Database.ConnMaxIdle) * time.Second,
SSLMode: config.Cfg.Database.SSLMode,
}
}
func (c *Config) DSN() string {
switch c.Type {
case Postgres:
return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
c.Host, c.Port, c.User, c.Password, c.Database, c.SSLMode)
default: // MySQL, TiDB, TDSQL
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local",
c.User, c.Password, c.Host, c.Port, c.Database, c.Charset)
}
}
func Open(cfg Config) (*gorm.DB, error) {
var dialector gorm.Dialector
switch cfg.Type {
case Postgres:
dialector = postgres.Open(cfg.DSN())
default: // MySQL, TiDB, TDSQL
dialector = mysql.Open(cfg.DSN())
}
db, err := gorm.Open(dialector, &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: gormLogger.Default.LogMode(gormLogger.Silent),
})
if err != nil {
return nil, fmt.Errorf("open database failed: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("get sql.DB failed: %w", err)
}
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
sqlDB.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
return db, nil
}
func Init(cfg Config, models ...interface{}) error {
var err error
dbOnce.Do(func() {
DB, err = Open(cfg)
if err != nil {
err = fmt.Errorf("init database failed: %w", err)
return
}
if len(models) > 0 {
if err = DB.AutoMigrate(models...); err != nil {
err = fmt.Errorf("auto migrate failed: %w", err)
return
}
}
})
return err
}
func GetDB() *gorm.DB {
if DB == nil {
panic("database not initialized, call Init() first")
}
return DB
}
func WithContext(ctx context.Context) *gorm.DB {
return GetDB().WithContext(ctx)
}
func Transaction(fn func(tx *gorm.DB) error) error {
return GetDB().Transaction(fn)
}
func Close() error {
if DB == nil {
return nil
}
sqlDB, err := DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
func DropIndex(value interface{}, indexName string) error {
db := GetDB()
if !db.Migrator().HasIndex(value, indexName) {
return nil
}
err := db.Migrator().DropIndex(value, indexName)
if err != nil && !strings.Contains(err.Error(), "1091") {
return fmt.Errorf("drop index %s failed: %w", indexName, err)
}
return nil
}
// Initialize (backward compatibility)
func init() {
if config.Cfg == nil {
return
}
// Compatibility with old configurations
if config.Cfg.Database.Host == "" && config.Cfg.Mysql.Host != "" {
// Use old MySQL configuration
cfg := Config{
Type: MySQL,
Host: config.Cfg.Mysql.Host,
Port: config.Cfg.Mysql.Port,
User: config.Cfg.Mysql.User,
Password: config.Cfg.Mysql.Password,
Database: "oneterm",
Charset: "utf8mb4",
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: time.Hour,
ConnMaxIdleTime: time.Minute * 10,
}
if err := Init(cfg); err != nil {
logger.L().Fatal("init database failed", zap.Error(err))
}
return
}
// Use new configuration
if config.Cfg.Database.Host != "" {
cfg := ConfigFromGlobal()
if err := Init(cfg); err != nil {
logger.L().Fatal("init database failed", zap.Error(err))
}
}
}

View File

@@ -7,12 +7,12 @@ import (
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
) )
func init() { func init() {
level := zapcore.DebugLevel level := zapcore.DebugLevel
switch conf.Cfg.Log.Level { switch config.Cfg.Log.Level {
case "error": case "error":
level = zapcore.ErrorLevel level = zapcore.ErrorLevel
case "warn": case "warn":
@@ -38,7 +38,7 @@ func init() {
zapcore.AddSync(fw), zapcore.AddSync(fw),
level, level,
)} )}
if conf.Cfg.Log.ConsoleEnable { if config.Cfg.Log.ConsoleEnable {
cores = append(cores, zapcore.NewCore( cores = append(cores, zapcore.NewCore(
encoder, encoder,
zapcore.AddSync(zapcore.Lock(os.Stderr)), zapcore.AddSync(zapcore.Lock(os.Stderr)),

View File

@@ -14,9 +14,9 @@ import (
"github.com/spf13/cast" "github.com/spf13/cast"
"go.uber.org/zap" "go.uber.org/zap"
redis "github.com/veops/oneterm/cache" "github.com/veops/oneterm/pkg/cache"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/logger" "github.com/veops/oneterm/pkg/logger"
) )
var ( var (
@@ -24,11 +24,11 @@ var (
) )
func GetAclToken(ctx context.Context) (res string, err error) { func GetAclToken(ctx context.Context) (res string, err error) {
res, err = redis.RC.Get(ctx, "aclToken").Result() res, err = cache.RC.Get(ctx, "aclToken").Result()
if err == nil { if err == nil {
return return
} }
aclConfig := conf.Cfg.Auth.Acl aclConfig := config.Cfg.Auth.Acl
url := fmt.Sprintf("%s%s", aclConfig.Url, "/acl/apps/token") url := fmt.Sprintf("%s%s", aclConfig.Url, "/acl/apps/token")
secretHash := md5.Sum([]byte(aclConfig.SecretKey)) secretHash := md5.Sum([]byte(aclConfig.SecretKey))
@@ -44,7 +44,7 @@ func GetAclToken(ctx context.Context) (res string, err error) {
} }
res = data["token"] res = data["token"]
_, err = redis.RC.SetNX(ctx, "aclToken", res, time.Hour).Result() _, err = cache.RC.SetNX(ctx, "aclToken", res, time.Hour).Result()
return return
} }

View File

@@ -1,4 +1,4 @@
package util package utils
import ( import (
"bytes" "bytes"
@@ -6,7 +6,7 @@ import (
"crypto/cipher" "crypto/cipher"
"encoding/base64" "encoding/base64"
"github.com/veops/oneterm/conf" "github.com/veops/oneterm/pkg/config"
) )
var ( var (
@@ -14,8 +14,8 @@ var (
) )
func init() { func init() {
key = []byte(conf.Cfg.Auth.Aes.Key) key = []byte(config.Cfg.Auth.Aes.Key)
iv = []byte(conf.Cfg.Auth.Aes.Iv) iv = []byte(config.Cfg.Auth.Aes.Iv)
} }
func EncryptAES(plainText string) string { func EncryptAES(plainText string) string {

View File

@@ -1,4 +1,4 @@
package util package utils
import ( import (
"testing" "testing"

View File

@@ -1,4 +1,4 @@
package util package utils
import "net" import "net"

View File

@@ -1,25 +0,0 @@
package schedule
import (
"time"
redis "github.com/veops/oneterm/cache"
mysql "github.com/veops/oneterm/db"
"github.com/veops/oneterm/model"
)
func UpdateConfig() {
cfg := &model.Config{}
defer func() {
redis.SetEx(ctx, "config", cfg, time.Hour)
model.GlobalConfig.Store(cfg)
}()
err := redis.Get(ctx, "config", cfg)
if err == nil {
return
}
err = mysql.DB.Model(cfg).First(cfg).Error
if err != nil {
return
}
}