mirror of
https://github.com/veops/oneterm.git
synced 2025-10-20 14:06:03 +08:00
fix(api): connectable
This commit is contained in:
@@ -109,7 +109,10 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
|
|||||||
db = db.Where("parent_id IN ?", parentIds)
|
db = db.Where("parent_id IN ?", parentIds)
|
||||||
}
|
}
|
||||||
|
|
||||||
if info && !acl.IsAdmin(currentUser) {
|
if info {
|
||||||
|
db = db.Select("id", "parent_id", "name", "ip", "protocols", "connectable", "authorization")
|
||||||
|
|
||||||
|
if !acl.IsAdmin(currentUser) {
|
||||||
ids, err := GetAssetIdsByAuthorization(ctx)
|
ids, err := GetAssetIdsByAuthorization(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
||||||
@@ -117,6 +120,7 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
db = db.Where("id IN ?", ids)
|
db = db.Where("id IN ?", ids)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
db = db.Order("name")
|
db = db.Order("name")
|
||||||
|
|
||||||
|
@@ -346,7 +346,7 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip, port, err := util.Proxy(sess.SessionId, "ssh", asset, gateway)
|
ip, port, err := util.Proxy(false, sess.SessionId, "ssh", asset, gateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -266,8 +266,8 @@ func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType
|
|||||||
}
|
}
|
||||||
if needAcl {
|
if needAcl {
|
||||||
md.SetResourceId(old.GetResourceId())
|
md.SetResourceId(old.GetResourceId())
|
||||||
fmt.Printf("%+v\n", old)
|
// fmt.Printf("%+v\n", old)
|
||||||
fmt.Printf("%+v\n", md)
|
// fmt.Printf("%+v\n", md)
|
||||||
if !hasPerm(ctx, md, resourceType, acl.WRITE) {
|
if !hasPerm(ctx, md, resourceType, acl.WRITE) {
|
||||||
ctx.AbortWithError(http.StatusForbidden, &ApiError{Code: ErrNoPerm, Data: map[string]any{"perm": acl.WRITE}})
|
ctx.AbortWithError(http.StatusForbidden, &ApiError{Code: ErrNoPerm, Data: map[string]any{"perm": acl.WRITE}})
|
||||||
return
|
return
|
||||||
|
@@ -98,7 +98,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db = db.Where("id NOT IN ?", ids)
|
db = db.Where("id IN ?", ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
if id, ok := ctx.GetQuery("self_parent"); ok {
|
if id, ok := ctx.GetQuery("self_parent"); ok {
|
||||||
@@ -109,7 +109,9 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
|
|||||||
db = db.Where("id IN ?", ids)
|
db = db.Where("id IN ?", ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
if info && !acl.IsAdmin(currentUser) {
|
if info {
|
||||||
|
db = db.Select("id", "parent_id", "name")
|
||||||
|
if !acl.IsAdmin(currentUser) {
|
||||||
ids, err := GetNodeIdsByAuthorization(ctx)
|
ids, err := GetNodeIdsByAuthorization(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -119,6 +121,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
|
|||||||
}
|
}
|
||||||
db = db.Where("id IN ?", ids)
|
db = db.Where("id IN ?", ids)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
doGet(ctx, !info, db, conf.RESOURCE_NODE, nodePostHooks...)
|
doGet(ctx, !info, db, conf.RESOURCE_NODE, nodePostHooks...)
|
||||||
}
|
}
|
||||||
@@ -253,7 +256,7 @@ func nodeDelHook(ctx *gin.Context, id int) {
|
|||||||
ctx.AbortWithError(http.StatusBadRequest, err)
|
ctx.AbortWithError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
||||||
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode)
|
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -274,10 +277,12 @@ func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
|||||||
}
|
}
|
||||||
dfs(0, false)
|
dfs(0, false)
|
||||||
|
|
||||||
|
res = lo.Uniq(append(res, ids...))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
|
func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
|
||||||
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode)
|
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -290,10 +295,10 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error)
|
|||||||
t := make([]int, 0)
|
t := make([]int, 0)
|
||||||
var dfs func(int)
|
var dfs func(int)
|
||||||
dfs = func(x int) {
|
dfs = func(x int) {
|
||||||
|
t = append(t, x)
|
||||||
if lo.Contains(ids, x) {
|
if lo.Contains(ids, x) {
|
||||||
res = append(res, t...)
|
res = append(res, t...)
|
||||||
}
|
}
|
||||||
t = append(t, x)
|
|
||||||
for _, y := range g[x] {
|
for _, y := range g[x] {
|
||||||
dfs(y)
|
dfs(y)
|
||||||
}
|
}
|
||||||
@@ -301,27 +306,39 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error)
|
|||||||
}
|
}
|
||||||
dfs(0)
|
dfs(0)
|
||||||
|
|
||||||
res = lo.Uniq(res)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
|
|
||||||
res, err = handleNoSelfParent(ctx, ids...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res = lo.Uniq(append(res, ids...))
|
res = lo.Uniq(append(res, ids...))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
||||||
res, err = handleNoSelfChild(ctx, ids...)
|
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res = lo.Uniq(append(res, ids...))
|
allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||||
|
|
||||||
|
res, err = handleSelfChild(ctx, ids...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res = lo.Uniq(lo.Without(allids, res...))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
|
||||||
|
nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||||
|
|
||||||
|
res, err = handleSelfParent(ctx, ids...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res = lo.Uniq(lo.Without(allids, res...))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -75,7 +75,7 @@ func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, port, err := util.Proxy(uuid.New().String(), "sftp,ssh", asset, gateway)
|
ip, port, err := util.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -92,7 +92,7 @@ func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, a
|
|||||||
t.Config.Parameters["recording-name"] = t.SessionId
|
t.Config.Parameters["recording-name"] = t.SessionId
|
||||||
}
|
}
|
||||||
if gateway != nil && gateway.Id != 0 && t.ConnectionId == "" {
|
if gateway != nil && gateway.Id != 0 && t.ConnectionId == "" {
|
||||||
t.gw, err = ggateway.GetGatewayManager().Open(t.SessionId, asset.Ip, cast.ToInt(port), gateway)
|
t.gw, err = ggateway.GetGatewayManager().Open(false, t.SessionId, asset.Ip, cast.ToInt(port), gateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return t, err
|
return t, err
|
||||||
}
|
}
|
||||||
|
@@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
manager = &GateWayManager{
|
manager = &GateWayManager{
|
||||||
gateways: map[string]*GatewayTunnel{},
|
gatewayTunnels: map[string]*GatewayTunnel{},
|
||||||
sshClients: map[int]*ssh.Client{},
|
sshClients: map[int]*ssh.Client{},
|
||||||
sshClientsCount: map[int]int{},
|
sshClientsCount: map[int]int{},
|
||||||
mtx: sync.Mutex{},
|
mtx: sync.Mutex{},
|
||||||
@@ -28,8 +28,8 @@ func GetGatewayManager() *GateWayManager {
|
|||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGatewayBySessionId(sessionId string) *GatewayTunnel {
|
func GetGatewayTunnelBySessionId(sessionId string) *GatewayTunnel {
|
||||||
return manager.gateways[sessionId]
|
return manager.gatewayTunnels[sessionId]
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayTunnel struct {
|
type GatewayTunnel struct {
|
||||||
@@ -45,13 +45,14 @@ type GatewayTunnel struct {
|
|||||||
Opened chan error
|
Opened chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gt *GatewayTunnel) Open() (err error) {
|
func (gt *GatewayTunnel) Open(isConnectable bool) (err error) {
|
||||||
go func() {
|
go func() {
|
||||||
<-time.After(time.Second * 3)
|
<-time.After(time.Second * 3)
|
||||||
logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId))
|
logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId))
|
||||||
gt.listener.Close()
|
gt.listener.Close()
|
||||||
}()
|
}()
|
||||||
defer func() {
|
defer func() {
|
||||||
|
logger.L().Debug("close listener", zap.String("sessionId", gt.SessionId), zap.Error(err))
|
||||||
gt.Opened <- err
|
gt.Opened <- err
|
||||||
}()
|
}()
|
||||||
gt.Opened <- nil
|
gt.Opened <- nil
|
||||||
@@ -65,11 +66,20 @@ func (gt *GatewayTunnel) Open() (err error) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
gt.RemoteConn, err = manager.sshClients[gt.GatewayId].DialContext(ctx, "tcp", remoteAddr)
|
gt.RemoteConn, err = manager.sshClients[gt.GatewayId].DialContext(ctx, "tcp", remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
defer func() {
|
||||||
|
if gt.LocalConn != nil {
|
||||||
defer gt.LocalConn.Close()
|
defer gt.LocalConn.Close()
|
||||||
|
}
|
||||||
|
if gt.RemoteConn != nil {
|
||||||
defer gt.RemoteConn.Close()
|
defer gt.RemoteConn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
|
logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if isConnectable {
|
||||||
|
return
|
||||||
|
}
|
||||||
go io.Copy(gt.LocalConn, gt.RemoteConn)
|
go io.Copy(gt.LocalConn, gt.RemoteConn)
|
||||||
go io.Copy(gt.RemoteConn, gt.LocalConn)
|
go io.Copy(gt.RemoteConn, gt.LocalConn)
|
||||||
|
|
||||||
@@ -77,13 +87,13 @@ func (gt *GatewayTunnel) Open() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GateWayManager struct {
|
type GateWayManager struct {
|
||||||
gateways map[string]*GatewayTunnel
|
gatewayTunnels map[string]*GatewayTunnel
|
||||||
sshClients map[int]*ssh.Client
|
sshClients map[int]*ssh.Client
|
||||||
sshClientsCount map[int]int
|
sshClientsCount map[int]int
|
||||||
mtx sync.Mutex
|
mtx sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) {
|
func (gm *GateWayManager) Open(isConnectable bool, sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) {
|
||||||
if gateway == nil {
|
if gateway == nil {
|
||||||
err = fmt.Errorf("gateway is nil")
|
err = fmt.Errorf("gateway is nil")
|
||||||
return
|
return
|
||||||
@@ -109,7 +119,7 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
logger.L().Debug("ssh client closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait()))
|
logger.L().Debug("ssh proxy wait closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait()))
|
||||||
delete(gm.sshClients, gateway.Id)
|
delete(gm.sshClients, gateway.Id)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -133,8 +143,8 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew
|
|||||||
RemotePort: remotePort,
|
RemotePort: remotePort,
|
||||||
Opened: make(chan error),
|
Opened: make(chan error),
|
||||||
}
|
}
|
||||||
gm.gateways[sessionId] = g
|
gm.gatewayTunnels[sessionId] = g
|
||||||
go g.Open()
|
go g.Open(isConnectable)
|
||||||
|
|
||||||
logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId))
|
logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId))
|
||||||
<-g.Opened
|
<-g.Opened
|
||||||
@@ -147,13 +157,15 @@ func (gm *GateWayManager) Close(sessionIds ...string) {
|
|||||||
gm.mtx.Lock()
|
gm.mtx.Lock()
|
||||||
defer gm.mtx.Unlock()
|
defer gm.mtx.Unlock()
|
||||||
for _, sid := range sessionIds {
|
for _, sid := range sessionIds {
|
||||||
gt, ok := gm.gateways[sid]
|
gt, ok := gm.gatewayTunnels[sid]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
gm.sshClientsCount[gt.GatewayId] -= 1
|
gm.sshClientsCount[gt.GatewayId] -= 1
|
||||||
if gm.sshClientsCount[gt.GatewayId] <= 0 {
|
if gm.sshClientsCount[gt.GatewayId] <= 0 {
|
||||||
gm.sshClients[gt.GatewayId].Close()
|
if g := gm.sshClients[gt.GatewayId]; g != nil {
|
||||||
|
g.Close()
|
||||||
|
}
|
||||||
delete(gm.sshClients, gt.GatewayId)
|
delete(gm.sshClients, gt.GatewayId)
|
||||||
delete(gm.sshClientsCount, gt.GatewayId)
|
delete(gm.sshClientsCount, gt.GatewayId)
|
||||||
}
|
}
|
||||||
|
@@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/spf13/cast"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
mysql "github.com/veops/oneterm/db"
|
mysql "github.com/veops/oneterm/db"
|
||||||
@@ -80,31 +79,28 @@ func UpdateConnectables(ids ...int) (err error) {
|
|||||||
|
|
||||||
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
|
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
|
||||||
sid = uuid.New().String()
|
sid = uuid.New().String()
|
||||||
for _, p := range asset.Protocols {
|
ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",")
|
||||||
ip, port := asset.Ip, cast.ToInt(strings.Split(p, ":")[1])
|
ip, port, err := util.Proxy(true, sid, ps, asset, gateway)
|
||||||
var (
|
|
||||||
gt *ggateway.GatewayTunnel
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if asset.GatewayId != 0 {
|
|
||||||
gt, err = ggateway.GetGatewayManager().Open(sid, ip, port, gateway)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L().Debug("open gateway failed", zap.Error(err))
|
logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err))
|
||||||
continue
|
return
|
||||||
}
|
|
||||||
ip, port = gt.LocalIp, gt.LocalPort
|
|
||||||
<-gt.Opened
|
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf("%s:%d", ip, port)
|
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||||
net, err := net.DialTimeout("tcp", addr, time.Second*3)
|
conn, err := net.DialTimeout("tcp", addr, time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err))
|
logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err))
|
||||||
continue
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
if asset.GatewayId != 0 {
|
||||||
|
t := ggateway.GetGatewayTunnelBySessionId(sid)
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = <-t.Opened; err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
defer net.Close()
|
|
||||||
|
|
||||||
ok = true
|
ok = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
@@ -58,7 +58,7 @@ func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Proxy(sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
|
func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
|
||||||
ip, port = asset.Ip, 0
|
ip, port = asset.Ip, 0
|
||||||
for _, tp := range strings.Split(protocol, ",") {
|
for _, tp := range strings.Split(protocol, ",") {
|
||||||
for _, p := range asset.Protocols {
|
for _, p := range asset.Protocols {
|
||||||
@@ -74,8 +74,7 @@ func Proxy(sessionId string, protocol string, asset *model.Asset, gateway *model
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
g, err := ggateway.GetGatewayManager().Open(isConnectable, sessionId, ip, port, gateway)
|
||||||
g, err := ggateway.GetGatewayManager().Open(sessionId, ip, port, gateway)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user