refactor: remove cte

This commit is contained in:
ttk
2024-03-08 09:09:22 +08:00
parent ebdd4b55dc
commit f0f8312056
6 changed files with 282 additions and 235 deletions

View File

@@ -2,7 +2,6 @@ package controller
import (
"errors"
"fmt"
"net/http"
"strings"
@@ -19,103 +18,9 @@ import (
)
var (
nodePreHooks = []preHook[*model.Node]{
func(ctx *gin.Context, data *model.Node) {
ids := make([]int, 0)
if err := mysql.DB.Raw(fmt.Sprintf(`
WITH RECURSIVE cte AS(
SELECT id
FROM node
WHERE id=%s AND deleted_at = 0
UNION ALL
SELECT t.id
FROM cte
INNER JOIN node t on cte.id = t.parent_id
WHERE deleted_at = 0
)
SELECT
id
FROM cte
`, ctx.Param("id"))).
Find(&ids).
Error; err != nil || lo.Contains(ids, data.ParentId) {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument})
}
},
}
nodePostHooks = []postHook[*model.Node]{
func(ctx *gin.Context, data []*model.Node) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
isAdmin := acl.IsAdmin(currentUser)
post := make([]*model.NodeCount, 0)
sql := fmt.Sprintf(`
WITH RECURSIVE cte AS(
SELECT parent_id
FROM asset
%s
UNION ALL
SELECT t.parent_id
FROM cte
INNER JOIN node t on cte.parent_id = t.id
WHERE deleted_at = 0
)
SELECT
parent_id,
COUNT(*) AS count
FROM cte
GROUP BY parent_id
`, lo.Ternary(isAdmin, "WHERE deleted_at = 0", "WHERE deleted_at = 0 AND id IN (?)"))
db := mysql.DB.
Model(&model.Asset{})
if isAdmin {
db = db.Raw(sql)
} else {
authorizationResourceIds, err := GetAutorizationResourceIds(ctx)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
db = db.Raw(sql, mysql.DB.Model(&model.Authorization{}).Select("asset_id").Where("resource_id IN ?", authorizationResourceIds))
}
if err := db.
Find(&post).
Error; err != nil {
logger.L.Error("node posthookfailed asset count", zap.Error(err))
return
}
m := lo.SliceToMap(post, func(p *model.NodeCount) (int, int64) { return p.ParentId, p.Count })
for _, d := range data {
d.AssetCount = m[d.Id]
}
}, func(ctx *gin.Context, data []*model.Node) {
ps := make([]int, 0)
if err := mysql.DB.
Model(&model.Node{}).
Where("parent_id IN ?", lo.Map(data, func(n *model.Node, _ int) int { return n.Id })).
Pluck("parent_id", &ps).
Error; err != nil {
logger.L.Error("node posthookfailed has child", zap.Error(err))
return
}
pm := lo.SliceToMap(ps, func(pid int) (int, bool) { return pid, true })
for _, n := range data {
n.HasChild = pm[n.Id]
}
},
}
nodeDcs = []deleteCheck{
func(ctx *gin.Context, id int) {
noChild := true
noChild = noChild && errors.Is(mysql.DB.Model(&model.Node{}).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound)
noChild = noChild && errors.Is(mysql.DB.Model(&model.Asset{}).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound)
if noChild {
return
}
err := &ApiError{Code: ErrHasChild, Data: nil}
ctx.AbortWithError(http.StatusBadRequest, err)
},
}
nodePreHooks = []preHook[*model.Node]{nodePreHookCheckCycle}
nodePostHooks = []postHook[*model.Node]{nodePostHookCountAsset, nodePostHookHasChild}
nodeDcs = []deleteCheck{nodeDelHook}
)
// CreateNode godoc
@@ -172,46 +77,151 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","), func(s string, _ int) int { return cast.ToInt(s) }))
}
if id, ok := ctx.GetQuery("no_self_child"); ok {
sql := fmt.Sprintf(`
WITH RECURSIVE cte AS(
SELECT id
FROM node
WHERE id=%s AND deleted_at = 0
UNION ALL
SELECT t.id
FROM cte
INNER JOIN node t on cte.id = t.parent_id
WHERE deleted_at = 0
)
SELECT
id
FROM cte
`, id)
sub := mysql.DB.Raw(sql)
db = db.Where("id NOT IN (?)", sub)
ids, err := handleNoSelfChild(cast.ToInt(id))
if err != nil {
return
}
db = db.Where("id NOT IN ?", ids)
}
if id, ok := ctx.GetQuery("self_parent"); ok {
sql := fmt.Sprintf(`
WITH RECURSIVE cte AS(
SELECT id,parent_id
FROM node
WHERE id=%s AND deleted_at = 0
UNION ALL
SELECT t.id,t.parent_id
FROM cte
INNER JOIN node t on cte.parent_id = t.id
WHERE deleted_at = 0
)
SELECT
id
FROM cte
`, id)
sub := mysql.DB.Raw(sql)
db = db.Where("id IN (?)", sub)
ids, err := handleSelfParent(cast.ToInt(id))
if err != nil {
return
}
db = db.Where("id IN ?", ids)
}
db = db.Order("name DESC")
doGet[*model.Node](ctx, false, db, "", nodePostHooks...)
}
func nodePreHookCheckCycle(ctx *gin.Context, data *model.Node) {
nodes := make([]*model.NodeIdPid, 0)
err := mysql.DB.Model(nodes).Find(&nodes).Error
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int) bool
dfs = func(x int) bool {
b := x == data.ParentId
for _, y := range g[x] {
b = b || dfs(y)
}
return b
}
if err != nil || dfs(cast.ToInt(ctx.Param("id"))) {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrInvalidArgument})
}
}
func nodePostHookCountAsset(ctx *gin.Context, data []*model.Node) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
isAdmin := acl.IsAdmin(currentUser)
assets := make([]*model.AssetIdPid, 0)
db := mysql.DB.Model(assets)
if !isAdmin {
authorizationResourceIds, err := GetAutorizationResourceIds(ctx)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
db = db.Where("resource_id IN ?", authorizationResourceIds)
}
if err := db.Find(&assets).Error; err != nil {
logger.L.Error("node posthookfailed asset count", zap.Error(err))
return
}
m := make(map[int]int64)
g := make(map[int][]int)
for _, n := range assets {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int) int64
dfs = func(x int) int64 {
m[x] += 1
for _, y := range g[x] {
m[x] += dfs(y)
}
return m[x]
}
for _, d := range data {
d.AssetCount = m[d.Id]
}
}
func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
ps := make([]int, 0)
if err := mysql.DB.
Model(&model.Node{}).
Where("parent_id IN ?", lo.Map(data, func(n *model.Node, _ int) int { return n.Id })).
Pluck("parent_id", &ps).
Error; err != nil {
logger.L.Error("node posthookfailed has child", zap.Error(err))
return
}
pm := lo.SliceToMap(ps, func(pid int) (int, bool) { return pid, true })
for _, n := range data {
n.HasChild = pm[n.Id]
}
}
func nodeDelHook(ctx *gin.Context, id int) {
noChild := true
noChild = noChild && errors.Is(mysql.DB.Model(&model.Node{}).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound)
noChild = noChild && errors.Is(mysql.DB.Model(&model.Asset{}).Select("id").Where("parent_id = ?", id).First(map[string]any{}).Error, gorm.ErrRecordNotFound)
if noChild {
return
}
err := &ApiError{Code: ErrHasChild, Data: nil}
ctx.AbortWithError(http.StatusBadRequest, err)
}
func handleNoSelfChild(id int) (ids []int, err error) {
nodes := make([]*model.NodeIdPid, 0)
if err = mysql.DB.Model(nodes).Find(&nodes).Error; err != nil {
return
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int)
dfs = func(x int) {
ids = append(ids, x)
for _, y := range g[x] {
dfs(y)
}
}
dfs(id)
return
}
func handleSelfParent(id int) (ids []int, err error) {
nodes := make([]*model.NodeIdPid, 0)
if err = mysql.DB.Model(nodes).Find(&nodes).Error; err != nil {
return
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int)
dfs = func(x int) {
ids = append(ids, x)
for _, y := range g[x] {
dfs(y)
}
}
dfs(id)
ids = append(lo.Without(lo.Keys(g), ids...), id)
return
}