mirror of
				https://github.com/veops/oneterm.git
				synced 2025-11-01 03:12:39 +08:00 
			
		
		
		
	fix(api): connectable
This commit is contained in:
		| @@ -109,13 +109,17 @@ func (c *Controller) GetAssets(ctx *gin.Context) { | ||||
| 		db = db.Where("parent_id IN ?", parentIds) | ||||
| 	} | ||||
|  | ||||
| 	if info && !acl.IsAdmin(currentUser) { | ||||
| 		ids, err := GetAssetIdsByAuthorization(ctx) | ||||
| 		if err != nil { | ||||
| 			ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) | ||||
| 			return | ||||
| 	if info { | ||||
| 		db = db.Select("id", "parent_id", "name", "ip", "protocols", "connectable", "authorization") | ||||
|  | ||||
| 		if !acl.IsAdmin(currentUser) { | ||||
| 			ids, err := GetAssetIdsByAuthorization(ctx) | ||||
| 			if err != nil { | ||||
| 				ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}}) | ||||
| 				return | ||||
| 			} | ||||
| 			db = db.Where("id IN ?", ids) | ||||
| 		} | ||||
| 		db = db.Where("id IN ?", ids) | ||||
| 	} | ||||
|  | ||||
| 	db = db.Order("name") | ||||
|   | ||||
| @@ -346,7 +346,7 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	ip, port, err := util.Proxy(sess.SessionId, "ssh", asset, gateway) | ||||
| 	ip, port, err := util.Proxy(false, sess.SessionId, "ssh", asset, gateway) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -266,8 +266,8 @@ func doUpdate[T model.Model](ctx *gin.Context, needAcl bool, md T, resourceType | ||||
| 	} | ||||
| 	if needAcl { | ||||
| 		md.SetResourceId(old.GetResourceId()) | ||||
| 		fmt.Printf("%+v\n", old) | ||||
| 		fmt.Printf("%+v\n", md) | ||||
| 		// fmt.Printf("%+v\n", old) | ||||
| 		// fmt.Printf("%+v\n", md) | ||||
| 		if !hasPerm(ctx, md, resourceType, acl.WRITE) { | ||||
| 			ctx.AbortWithError(http.StatusForbidden, &ApiError{Code: ErrNoPerm, Data: map[string]any{"perm": acl.WRITE}}) | ||||
| 			return | ||||
|   | ||||
| @@ -98,7 +98,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) { | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		db = db.Where("id NOT IN ?", ids) | ||||
| 		db = db.Where("id IN ?", ids) | ||||
| 	} | ||||
|  | ||||
| 	if id, ok := ctx.GetQuery("self_parent"); ok { | ||||
| @@ -109,15 +109,18 @@ func (c *Controller) GetNodes(ctx *gin.Context) { | ||||
| 		db = db.Where("id IN ?", ids) | ||||
| 	} | ||||
|  | ||||
| 	if info && !acl.IsAdmin(currentUser) { | ||||
| 		ids, err := GetNodeIdsByAuthorization(ctx) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 	if info { | ||||
| 		db = db.Select("id", "parent_id", "name") | ||||
| 		if !acl.IsAdmin(currentUser) { | ||||
| 			ids, err := GetNodeIdsByAuthorization(ctx) | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			if ids, err = handleSelfParent(ctx, ids...); err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			db = db.Where("id IN ?", ids) | ||||
| 		} | ||||
| 		if ids, err = handleSelfParent(ctx, ids...); err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		db = db.Where("id IN ?", ids) | ||||
| 	} | ||||
|  | ||||
| 	doGet(ctx, !info, db, conf.RESOURCE_NODE, nodePostHooks...) | ||||
| @@ -253,7 +256,7 @@ func nodeDelHook(ctx *gin.Context, id int) { | ||||
| 	ctx.AbortWithError(http.StatusBadRequest, err) | ||||
| } | ||||
|  | ||||
| func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| @@ -274,10 +277,12 @@ func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	} | ||||
| 	dfs(0, false) | ||||
|  | ||||
| 	res = lo.Uniq(append(res, ids...)) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| @@ -290,10 +295,10 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) | ||||
| 	t := make([]int, 0) | ||||
| 	var dfs func(int) | ||||
| 	dfs = func(x int) { | ||||
| 		t = append(t, x) | ||||
| 		if lo.Contains(ids, x) { | ||||
| 			res = append(res, t...) | ||||
| 		} | ||||
| 		t = append(t, x) | ||||
| 		for _, y := range g[x] { | ||||
| 			dfs(y) | ||||
| 		} | ||||
| @@ -301,27 +306,39 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) | ||||
| 	} | ||||
| 	dfs(0) | ||||
|  | ||||
| 	res = lo.Uniq(res) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	res, err = handleNoSelfParent(ctx, ids...) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	res = lo.Uniq(append(res, ids...)) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	res, err = handleNoSelfChild(ctx, ids...) | ||||
| func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	res = lo.Uniq(append(res, ids...)) | ||||
| 	allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id }) | ||||
|  | ||||
| 	res, err = handleSelfChild(ctx, ids...) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	res = lo.Uniq(lo.Without(allids, res...)) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error) { | ||||
| 	nodes, err := util.GetAllFromCacheDb(ctx, model.DefaultNode) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id }) | ||||
|  | ||||
| 	res, err = handleSelfParent(ctx, ids...) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	res = lo.Uniq(lo.Without(allids, res...)) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -75,7 +75,7 @@ func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client, | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	ip, port, err := util.Proxy(uuid.New().String(), "sftp,ssh", asset, gateway) | ||||
| 	ip, port, err := util.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
| @@ -92,7 +92,7 @@ func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, a | ||||
| 		t.Config.Parameters["recording-name"] = t.SessionId | ||||
| 	} | ||||
| 	if gateway != nil && gateway.Id != 0 && t.ConnectionId == "" { | ||||
| 		t.gw, err = ggateway.GetGatewayManager().Open(t.SessionId, asset.Ip, cast.ToInt(port), gateway) | ||||
| 		t.gw, err = ggateway.GetGatewayManager().Open(false, t.SessionId, asset.Ip, cast.ToInt(port), gateway) | ||||
| 		if err != nil { | ||||
| 			return t, err | ||||
| 		} | ||||
|   | ||||
| @@ -17,7 +17,7 @@ import ( | ||||
|  | ||||
| var ( | ||||
| 	manager = &GateWayManager{ | ||||
| 		gateways:        map[string]*GatewayTunnel{}, | ||||
| 		gatewayTunnels:  map[string]*GatewayTunnel{}, | ||||
| 		sshClients:      map[int]*ssh.Client{}, | ||||
| 		sshClientsCount: map[int]int{}, | ||||
| 		mtx:             sync.Mutex{}, | ||||
| @@ -28,8 +28,8 @@ func GetGatewayManager() *GateWayManager { | ||||
| 	return manager | ||||
| } | ||||
|  | ||||
| func GetGatewayBySessionId(sessionId string) *GatewayTunnel { | ||||
| 	return manager.gateways[sessionId] | ||||
| func GetGatewayTunnelBySessionId(sessionId string) *GatewayTunnel { | ||||
| 	return manager.gatewayTunnels[sessionId] | ||||
| } | ||||
|  | ||||
| type GatewayTunnel struct { | ||||
| @@ -45,13 +45,14 @@ type GatewayTunnel struct { | ||||
| 	Opened     chan error | ||||
| } | ||||
|  | ||||
| func (gt *GatewayTunnel) Open() (err error) { | ||||
| func (gt *GatewayTunnel) Open(isConnectable bool) (err error) { | ||||
| 	go func() { | ||||
| 		<-time.After(time.Second * 3) | ||||
| 		logger.L().Debug("timeout 3 second close listener", zap.String("sessionId", gt.SessionId)) | ||||
| 		gt.listener.Close() | ||||
| 	}() | ||||
| 	defer func() { | ||||
| 		logger.L().Debug("close listener", zap.String("sessionId", gt.SessionId), zap.Error(err)) | ||||
| 		gt.Opened <- err | ||||
| 	}() | ||||
| 	gt.Opened <- nil | ||||
| @@ -65,11 +66,20 @@ func (gt *GatewayTunnel) Open() (err error) { | ||||
| 	defer cancel() | ||||
| 	gt.RemoteConn, err = manager.sshClients[gt.GatewayId].DialContext(ctx, "tcp", remoteAddr) | ||||
| 	if err != nil { | ||||
| 		defer gt.LocalConn.Close() | ||||
| 		defer gt.RemoteConn.Close() | ||||
| 		defer func() { | ||||
| 			if gt.LocalConn != nil { | ||||
| 				defer gt.LocalConn.Close() | ||||
| 			} | ||||
| 			if gt.RemoteConn != nil { | ||||
| 				defer gt.RemoteConn.Close() | ||||
| 			} | ||||
| 		}() | ||||
| 		logger.L().Error("dial remote failed", zap.String("sessionId", gt.SessionId), zap.Error(err)) | ||||
| 		return | ||||
| 	} | ||||
| 	if isConnectable { | ||||
| 		return | ||||
| 	} | ||||
| 	go io.Copy(gt.LocalConn, gt.RemoteConn) | ||||
| 	go io.Copy(gt.RemoteConn, gt.LocalConn) | ||||
|  | ||||
| @@ -77,13 +87,13 @@ func (gt *GatewayTunnel) Open() (err error) { | ||||
| } | ||||
|  | ||||
| type GateWayManager struct { | ||||
| 	gateways        map[string]*GatewayTunnel | ||||
| 	gatewayTunnels  map[string]*GatewayTunnel | ||||
| 	sshClients      map[int]*ssh.Client | ||||
| 	sshClientsCount map[int]int | ||||
| 	mtx             sync.Mutex | ||||
| } | ||||
|  | ||||
| func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) { | ||||
| func (gm *GateWayManager) Open(isConnectable bool, sessionId, remoteIp string, remotePort int, gateway *model.Gateway) (g *GatewayTunnel, err error) { | ||||
| 	if gateway == nil { | ||||
| 		err = fmt.Errorf("gateway is nil") | ||||
| 		return | ||||
| @@ -109,7 +119,7 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew | ||||
| 			return | ||||
| 		} | ||||
| 		go func() { | ||||
| 			logger.L().Debug("ssh client closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait())) | ||||
| 			logger.L().Debug("ssh proxy wait closed", zap.Int("gatewayId", gateway.Id), zap.Error(sshCli.Wait())) | ||||
| 			delete(gm.sshClients, gateway.Id) | ||||
| 		}() | ||||
| 	} | ||||
| @@ -133,8 +143,8 @@ func (gm *GateWayManager) Open(sessionId, remoteIp string, remotePort int, gatew | ||||
| 		RemotePort: remotePort, | ||||
| 		Opened:     make(chan error), | ||||
| 	} | ||||
| 	gm.gateways[sessionId] = g | ||||
| 	go g.Open() | ||||
| 	gm.gatewayTunnels[sessionId] = g | ||||
| 	go g.Open(isConnectable) | ||||
|  | ||||
| 	logger.L().Debug("opening gateway", zap.Any("sessionId", sessionId)) | ||||
| 	<-g.Opened | ||||
| @@ -147,13 +157,15 @@ func (gm *GateWayManager) Close(sessionIds ...string) { | ||||
| 	gm.mtx.Lock() | ||||
| 	defer gm.mtx.Unlock() | ||||
| 	for _, sid := range sessionIds { | ||||
| 		gt, ok := gm.gateways[sid] | ||||
| 		gt, ok := gm.gatewayTunnels[sid] | ||||
| 		if !ok { | ||||
| 			return | ||||
| 		} | ||||
| 		gm.sshClientsCount[gt.GatewayId] -= 1 | ||||
| 		if gm.sshClientsCount[gt.GatewayId] <= 0 { | ||||
| 			gm.sshClients[gt.GatewayId].Close() | ||||
| 			if g := gm.sshClients[gt.GatewayId]; g != nil { | ||||
| 				g.Close() | ||||
| 			} | ||||
| 			delete(gm.sshClients, gt.GatewayId) | ||||
| 			delete(gm.sshClientsCount, gt.GatewayId) | ||||
| 		} | ||||
|   | ||||
| @@ -8,7 +8,6 @@ import ( | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/samber/lo" | ||||
| 	"github.com/spf13/cast" | ||||
| 	"go.uber.org/zap" | ||||
|  | ||||
| 	mysql "github.com/veops/oneterm/db" | ||||
| @@ -80,31 +79,28 @@ func UpdateConnectables(ids ...int) (err error) { | ||||
|  | ||||
| func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) { | ||||
| 	sid = uuid.New().String() | ||||
| 	for _, p := range asset.Protocols { | ||||
| 		ip, port := asset.Ip, cast.ToInt(strings.Split(p, ":")[1]) | ||||
| 		var ( | ||||
| 			gt  *ggateway.GatewayTunnel | ||||
| 			err error | ||||
| 		) | ||||
| 		if asset.GatewayId != 0 { | ||||
| 			gt, err = ggateway.GetGatewayManager().Open(sid, ip, port, gateway) | ||||
| 			if err != nil { | ||||
| 				logger.L().Debug("open gateway failed", zap.Error(err)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip, port = gt.LocalIp, gt.LocalPort | ||||
| 			<-gt.Opened | ||||
| 		} | ||||
| 		addr := fmt.Sprintf("%s:%d", ip, port) | ||||
| 		net, err := net.DialTimeout("tcp", addr, time.Second*3) | ||||
| 		if err != nil { | ||||
| 			logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err)) | ||||
| 			continue | ||||
| 		} | ||||
| 		defer net.Close() | ||||
|  | ||||
| 		ok = true | ||||
| 	ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",") | ||||
| 	ip, port, err := util.Proxy(true, sid, ps, asset, gateway) | ||||
| 	if err != nil { | ||||
| 		logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err)) | ||||
| 		return | ||||
| 	} | ||||
| 	addr := fmt.Sprintf("%s:%d", ip, port) | ||||
| 	conn, err := net.DialTimeout("tcp", addr, time.Second) | ||||
| 	if err != nil { | ||||
| 		logger.L().Debug("dail failed", zap.String("addr", addr), zap.Error(err)) | ||||
| 		return | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 	if asset.GatewayId != 0 { | ||||
| 		t := ggateway.GetGatewayTunnelBySessionId(sid) | ||||
| 		if t == nil { | ||||
| 			return | ||||
| 		} | ||||
| 		if err = <-t.Opened; err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	ok = true | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -58,7 +58,7 @@ func GetAuth(account *model.Account) (ssh.AuthMethod, error) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Proxy(sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) { | ||||
| func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) { | ||||
| 	ip, port = asset.Ip, 0 | ||||
| 	for _, tp := range strings.Split(protocol, ",") { | ||||
| 		for _, p := range asset.Protocols { | ||||
| @@ -74,8 +74,7 @@ func Proxy(sessionId string, protocol string, asset *model.Asset, gateway *model | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
|  | ||||
| 	g, err := ggateway.GetGatewayManager().Open(sessionId, ip, port, gateway) | ||||
| 	g, err := ggateway.GetGatewayManager().Open(isConnectable, sessionId, ip, port, gateway) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 ttk
					ttk