feat: 新增curd工具

This commit is contained in:
dabai
2022-08-15 14:25:02 +08:00
7 changed files with 556 additions and 17 deletions

8
config/database.yaml Normal file
View File

@@ -0,0 +1,8 @@
connections:
mysql:
driver: mysql
host: env("DB_HOST")
port: env("DB_PORT")
database: env("DB_DATABASE")
username: env("DB_USERNAME")
password: env("DB_PASSWORD")

513
console/commands/curd.go Normal file
View File

@@ -0,0 +1,513 @@
package commands
import (
"database/sql"
"fmt"
"github.com/ctfang/command"
"github.com/go-home-admin/toolset/console/commands/orm"
"github.com/go-home-admin/toolset/console/commands/pgorm"
"github.com/go-home-admin/toolset/parser"
"github.com/joho/godotenv"
"gopkg.in/yaml.v2"
"log"
"os"
"strconv"
"strings"
)
// CurdCommand @Bean
type CurdCommand struct{}
type TableColumn struct {
Name string
GoType string
Comment string
}
type Param struct {
CoonName string
TableName string
Module string
Explain string
DbName string
}
var param Param
func (CurdCommand) Configure() command.Configure {
return command.Configure{
Name: "make:curd",
Description: "生成curd基础代码, 默认使用交互输入, 便捷调用 ",
Input: command.Argument{
Option: []command.ArgParam{
{
Name: "conn_name",
Description: "连接名",
},
{
Name: "table_name",
Description: "表名",
},
{
Name: "config",
Description: "配置文件",
Default: "@root/config/database.yaml",
},
{
Name: "module",
Description: "模块名称, 默认: admin",
Default: "admin",
},
{
Name: "explain",
Description: "生成的注释, 默认为表注释",
Call: func(val string, c *command.Console) (string, bool) {
return "", true
},
},
},
},
}
}
func (CurdCommand) Execute(input command.Input) {
root := getRootPath()
err := godotenv.Load(root + "/.env")
if err != nil {
fmt.Println(root + "/.env" + "文件不存在, 无法加载环境变量")
}
file := input.GetOption("config")
file = strings.Replace(file, "@root", root, 1)
fileContext, _ := os.ReadFile(file)
fileContext = SetEnv(fileContext)
m := make(map[string]interface{})
err = yaml.Unmarshal(fileContext, &m)
if err != nil {
log.Printf("配置解析错误:%v", err)
return
}
connections := m["connections"].(map[interface{}]interface{})
param.CoonName = getConnName(input.GetOption("conn_name"), connections)
param.TableName = getTableName(input.GetOption("table_name"), connections[param.CoonName])
config := connections[param.CoonName].(map[interface{}]interface{})
TableColumns := GetTableColumn(config, param.TableName)
module := input.GetOption("module")
outUrl := root + "/app/http/" + module + "/" + param.TableName
_, err = os.Stat(outUrl)
if os.IsNotExist(err) {
err = os.MkdirAll(outUrl, 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
protoUrl := root + "/protobuf/" + module + "/" + param.TableName
_, err = os.Stat(protoUrl)
if os.IsNotExist(err) {
err = os.MkdirAll(protoUrl, 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
index := strings.Index(param.TableName, "_")
contName := ""
if index >= 0 {
other := param.TableName[index+1:]
contName = strings.ToUpper(param.TableName[:1]) + param.TableName[1:index] + strings.ToUpper(other[:1]) + other[1:]
} else {
contName = strings.ToUpper(param.TableName)
}
goMod := getModModule()
//controller
buildController(param, outUrl, goMod, contName)
//del
buildDel(param, outUrl, goMod, "del", contName)
//get
buildGet(param, outUrl, goMod, "get", contName, TableColumns)
//post
buildPost(param, outUrl, goMod, "post", contName, TableColumns)
//put
buildPut(param, outUrl, goMod, "put", contName, TableColumns)
//proto
buildProto(param, protoUrl, goMod, contName, TableColumns)
}
// 获取连接名称
func getConnName(connName string, connections map[interface{}]interface{}) string {
if connName == "" {
var got int
gotName := make(map[int]string)
fmt.Printf("请选中以下连接数据库配置\n")
for name, m := range connections {
conf := m.(map[interface{}]interface{})
driver := conf["driver"]
if driver == "mysql" || driver == "pgsql" {
got++
gotName[got] = name.(string)
fmt.Printf("%v: %v\n", got, name)
}
}
if len(gotName) == 1 {
got = 1
fmt.Printf("只有一个数据库, 已经自动选中: 1\n")
} else {
fmt.Printf("请输入数字: ")
fmt.Scan(&got)
}
connName = gotName[got]
}
if _, ok := connections[connName]; !ok {
panic("没有找不到对应数据库连接")
}
return connName
}
func getTableName(tableName string, m interface{}) string {
conf := m.(map[interface{}]interface{})
if tableName == "" {
fmt.Printf("未指定表, 可以使用以下的表生成\n")
tables := make(map[int]string)
switch conf["driver"] {
case "mysql":
db := orm.NewDb(conf)
rows, _ := db.GetDB().Query("SELECT A.TABLE_NAME as name FROM INFORMATION_SCHEMA.COLUMNS A WHERE A.TABLE_SCHEMA = ? GROUP BY TABLE_NAME", conf["database"].(string))
defer rows.Close()
i := 0
for rows.Next() {
i++
var name string
rows.Scan(&name)
tables[i] = name
fmt.Printf("%v: %v\n", i, name)
}
case "pgsql":
}
fmt.Printf("请输入数字: ")
var got int
fmt.Scan(&got)
tableName = tables[got]
}
return tableName
}
func buildController(param Param, outUrl string, module string, contName string) {
cont := outUrl + "/" + param.TableName + "_controller.go"
str := "package " + param.TableName
str += "\n\n// " + param.Explain
str += "\ntype Controller struct {}"
err := os.WriteFile(cont, []byte(str), 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
func buildHead(tableName string, module string, name string) string {
str := "package " + tableName
tm := []string{
"github.com/gin-gonic/gin",
module + "/app/common/auth",
module + "/app/entity/" + param.CoonName,
module + "/app/providers",
module + "/generate/proto/admin",
module + "/home/app/http",
}
str += "\n\nimport ("
for _, v := range tm {
str += "\n \"" + v + "\""
}
str += "\n)"
return str
}
func buildDel(param Param, outUrl string, module string, name string, contName string) {
cont := outUrl + "/del_action.go"
str := buildHead(param.TableName, module, name)
str += "\n\n// Del 删除数据 - " + param.Explain
str += "\nfunc (receiver *Controller) Del(req *admin" + "." + contName + "PutRequest, ctx *auth.Context) (*admin" + "." + contName + "PutRequest, error) {"
str += "\n id := ctx.GetId()"
str += "\n err := " + param.CoonName + ".NewOrm" + contName + ".Delete(id)"
str += "\n return &admin" + "." + contName + "PutRequest{"
str += "\n Tip: \"OK\","
str += "\n }, err.Error"
str += "\n}"
str += "\n"
str += handleValue(name, "admin", contName)
err := os.WriteFile(cont, []byte(str), 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
func buildGet(param Param, outUrl string, module string, name string, contName string, column []TableColumn) {
cont := outUrl + "/get_action.go"
str := buildHead(param.TableName, module, name)
str += "\n\n// Get 列表查询 - " + param.Explain
co := ""
for _, v := range column {
co += fmt.Sprintf(`
%v: cp.%v,`,
v.Name, v.Name)
}
str += fmt.Sprintf(`
func (receiver *Controller) Get(req *%v.%vGetRequest, ctx *auth.Context) (*%v.%vGetResponse, error) {
list, total := %v.NewOrm%v().Paginate(int(req.Page), int(req.Limit))
responseList := make([]*%v.%vInfo, 0)
for _, cp := range list {
responseList = append(responseList, &%v.%vInfo{%v
})
}
return &%v.%vGetResponse{
List: responseList,
Total: uint32(total),
}, nil
}
`,
"admin",
contName,
"admin",
contName,
param.CoonName,
contName,
"admin",
contName,
"admin",
contName,
co,
"admin",
contName,
)
str += handleValue(name, "admin", contName)
err := os.WriteFile(cont, []byte(str), 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
func buildPost(param Param, outUrl string, module string, name string, contName string, column []TableColumn) {
cont := outUrl + "/post_action.go"
str := buildHead(param.TableName, module, name)
str += "\n\n// Post 创建新数据 - " + param.Explain
str += "\nfunc (receiver *Controller) Post(req *admin" + "." + contName + "PostRequest, ctx *auth.Context) (*admin" + "." + contName + "PostResponse, error) {"
str += "\n id := int32(common.GetParamId(ctx))"
str += "\n has := receiver.orm.WhereId(id).First()"
str += "\n if has == nil {"
str += "\n return nil, nil"
str += "\n }"
split := strings.Split(param.TableName, "_")
dbFunc := ""
for _, t := range split {
dbFunc += strings.ToUpper(t[:1]) + t[1:]
}
str += "\n data := " + param.CoonName + "." + dbFunc + "{"
for _, v := range column {
str += "\n " + v.Name + ": " + "cp." + v.Name + ","
}
str += "\n }"
str += "\n res := receiver.orm.Create(&data)"
str += "\n return &admin" + "." + contName + "PostResponse{}, res.Error"
str += "\n}"
str += handleValue(name, "admin", contName)
err := os.WriteFile(cont, []byte(str), 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
func buildPut(param Param, outUrl string, module string, name string, contName string, column []TableColumn) {
cont := outUrl + "/put_action.go"
str := buildHead(param.TableName, module, name)
str += "\n\n// Put 更新数据 - " + param.Explain
str += "\nfunc (receiver *Controller) Put(req *admin" + "." + contName + "PostRequest, ctx *auth.Context) (*admin" + "." + contName + "PostRequest, error) {"
str += "\n id := int32(common.GetParamId(ctx))"
str += "\n has := receiver.orm.WhereId(id).First()"
str += "\n if has == nil {"
str += "\n return nil, nil"
str += "\n }"
split := strings.Split(param.TableName, "_")
dbFunc := ""
for _, t := range split {
dbFunc += strings.ToUpper(t[:1]) + t[1:]
}
str += "\n err := receiver.orm.WhereId(id).Updates(&" + param.CoonName + "." + dbFunc + "{"
for _, v := range column {
str += "\n " + v.Name + ": " + "cp." + v.Name + ","
}
str += "})"
str += "\n return &admin" + "." + contName + "PutResponse{}, err.Error"
str += "\n}"
str += handleValue(name, "admin", contName)
err := os.WriteFile(cont, []byte(str), 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
func handleValue(name string, module string, contName string) string {
name = strings.ToUpper(name[:1]) + name[1:]
str := fmt.Sprintf(`
//GinHandle%v gin原始路由处理`, name)
str += fmt.Sprintf(`
func (receiver *Controller) GinHandle%v(ctx *gin.Context) {
req := &%v.%v%vRequest{}
err := ctx.ShouleBind(req)
if err != nil {
providers.ErrorRequest(ctx, err)
return
}
resp, err := receiver.%v(req, auth.NewContext(ctx))
if err != nil {
providers.ErrorResponse(ctx, err)
return
}
providers.SuccessResponse(ctx, resp)
}
`,
name,
module,
contName,
name,
name,
)
return str
}
func buildProto(param Param, protoUrl string, module string, contName string, column []TableColumn) {
cont := protoUrl + "/" + param.TableName + ".proto"
str := "// @Tag(\"form\");"
str += "\nsyntax = \"proto3\";"
str += "\n\npackage " + param.TableName + ";"
str += "\n\nimport \"http_config.proto\";"
str += "\n\noption go_package = \"" + module + "/generate/proto/admin\"" + ";"
str += "\n// " + param.Explain + "资源控制器"
str += "\nservice " + contName + "{"
str += "\n // 需要登录"
str += "\n option (http.RouteGroup) = \"login\";"
str += "\n // " + param.Explain + "列表"
str += "\n rpc Get(" + contName + "GetRequest) returns (" + contName + "GetResponse){"
str += "\n option (http.Get) = \"/admin" + "/" + param.TableName + "\";"
str += "\n }"
str += "\n // " + param.Explain + "创建"
str += "\n rpc Post(" + contName + "PostRequest) returns (" + contName + "PostResponse){"
str += "\n option (http.Post) = \"/admin" + "/" + param.TableName + "\";"
str += "\n }"
str += "\n // " + param.Explain + "更新"
str += "\n rpc Put(" + contName + "PutRequest) returns (" + contName + "PutResponse){"
str += "\n option (http.Put) = \"/admin" + "/" + param.TableName + "/:id\";"
str += "\n }"
str += "\n // " + param.Explain + "删除"
str += "\n rpc Del(" + contName + "PutRequest) returns (" + contName + "PutResponse){"
str += "\n option (http.Get) = \"/admin" + "/" + param.TableName + "/:id\";"
str += "\n }"
str += "\n}"
str += "\nmessage " + contName + "GetRequest {"
str += "\n // 列表第几页默认1"
str += "\n uint32 page = 1;"
str += "\n // 每页多少条数据"
str += "\n uint32 limit = 2;"
str += "\n}"
str += "\nmessage " + contName + "GetResponse {"
str += "\n // 数据列表"
str += "\n repeated " + contName + "Info list = 1;"
str += "\n // 符合条件总数量,计算多少页"
str += "\n uint32 total = 2;"
str += "\n}"
str += "\nmessage " + contName + "PostRequest {"
str += "\n}"
str += "\nmessage " + contName + "PostResponse {"
str += "\n // 提示语"
str += "\n string tip = 1;"
str += "\n}"
str += "\nmessage " + contName + "PutRequest {"
str += "\n"
str += "\n"
str += "\n}"
str += "\nmessage " + contName + "PutResponse {"
str += "\n // 提示语"
str += "\n string tip = 1;"
str += "\n}"
str += "\nmessage " + contName + "Info{"
for i, v := range column {
str += "\n // " + v.Comment
if v.GoType == "database.Time" {
v.GoType = "string"
}
str += "\n " + v.GoType + " " + v.Name + " = " + strconv.Itoa(i+1) + ";"
}
str += "\n}"
err := os.WriteFile(cont, []byte(str), 0766)
if err != nil {
log.Printf(err.Error())
return
}
}
func GetTableColumn(config map[interface{}]interface{}, tableName string) []TableColumn {
rows := &sql.Rows{}
var err error
switch config["driver"] {
case "mysql":
rows, err = orm.NewDb(config).GetDB().Query(`
SELECT COLUMN_NAME, DATA_TYPE, COLUMN_COMMENT
FROM information_schema.COLUMNS
WHERE table_schema = DATABASE () AND table_name = ?
ORDER BY ORDINAL_POSITION ASC`, tableName)
case "pgsql":
db := pgorm.NewDb(config)
rows, err = db.GetDB().Query(`
SELECT i.column_name, i.udt_name, col_description(a.attrelid,a.attnum) as comment
FROM information_schema.columns as i
LEFT JOIN pg_class as c on c.relname = i.table_name
LEFT JOIN pg_attribute as a on a.attrelid = c.oid and a.attname = i.column_name
WHERE table_schema = 'public' and i.table_name = $1;
`, tableName)
default:
panic("没有[" + config["driver"].(string) + "]的驱动")
}
defer rows.Close()
if err != nil {
panic("数据库连接失败或没有找到对应的表")
}
var tableColumns []TableColumn
for rows.Next() {
var name, dataType, comment string
var _comment *string
_ = rows.Scan(
&name,
&dataType,
&_comment,
)
if _comment == nil {
comment = ""
} else {
comment = *_comment
}
switch config["driver"] {
case "mysql":
dataType = orm.TypeForMysqlToGo[dataType]
case "pgsql":
dataType = pgorm.PgTypeToGoType(dataType, name)
}
tableColumns = append(tableColumns, TableColumn{
Name: parser.StringToHump(name),
GoType: dataType,
Comment: comment,
})
}
return tableColumns
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/go-home-admin/toolset/console/commands/pgorm"
"github.com/joho/godotenv"
"gopkg.in/yaml.v2"
"log"
"os"
"os/exec"
"strings"
@@ -52,10 +53,10 @@ func (OrmCommand) Execute(input command.Input) {
fileContext = SetEnv(fileContext)
m := make(map[string]interface{})
err = yaml.Unmarshal(fileContext, &m)
//if err != nil {
// panic(err)
//}
if err != nil {
log.Printf("配置解析错误:%v", err)
return
}
connections := m["connections"].(map[interface{}]interface{})
for s, confT := range connections {
conf := confT.(map[interface{}]interface{})
@@ -90,13 +91,14 @@ func SetEnv(fileContext []byte) []byte {
}
nS := arr2[1]
st, et := GetBrackets(nS, '"', '"')
key := nS[st : et+1]
key := strings.Trim(nS[st:et+1], "\"")
nS = nS[et+1:]
st, et = GetBrackets(nS, '"', '"')
val := nS[st : et+1]
key = strings.Trim(key, "\"")
val = strings.Trim(val, "\"")
var val string
if et > 0 {
val = nS[st : et+1]
val = strings.Trim(val, "\"")
}
envVal := os.Getenv(key)
if envVal != "" {
val = envVal

View File

@@ -321,6 +321,10 @@ type DB struct {
db *sql.DB
}
func (d *DB) GetDB() *sql.DB {
return d.db
}
func (d *DB) tableColumns() map[string][]tableColumn {
var sqlStr = "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"

View File

@@ -6,6 +6,7 @@ import (
)
var _BeanCommandSingle *BeanCommand
var _CurdCommandSingle *CurdCommand
var _DevCommandSingle *DevCommand
var _JsSingle *Js
var _OrmCommandSingle *OrmCommand
@@ -16,6 +17,7 @@ var _SwaggerCommandSingle *SwaggerCommand
func GetAllProvider() []interface{} {
return []interface{}{
NewBeanCommand(),
NewCurdCommand(),
NewDevCommand(),
NewJs(),
NewOrmCommand(),
@@ -32,6 +34,13 @@ func NewBeanCommand() *BeanCommand {
}
return _BeanCommandSingle
}
func NewCurdCommand() *CurdCommand {
if _CurdCommandSingle == nil {
_CurdCommandSingle = &CurdCommand{}
providers.AfterProvider(_CurdCommandSingle, "")
}
return _CurdCommandSingle
}
func NewDevCommand() *DevCommand {
if _DevCommandSingle == nil {
_DevCommandSingle = &DevCommand{}

15
go.mod
View File

@@ -3,16 +3,19 @@ module github.com/go-home-admin/toolset
go 1.16
require (
github.com/ctfang/command v1.0.0
github.com/gin-gonic/gin v1.8.0 // indirect
github.com/ctfang/command v1.0.1
github.com/go-home-admin/home v0.0.3
github.com/go-playground/validator/v10 v10.11.0 // indirect
github.com/go-sql-driver/mysql v1.6.0
github.com/joho/godotenv v1.4.0
github.com/lib/pq v1.10.6
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect
golang.org/x/net v0.0.0-20220531201128-c960675eff93 // indirect
gopkg.in/yaml.v2 v2.4.0
gorm.io/driver/mysql v1.3.4 // indirect
gorm.io/gorm v1.23.5 // indirect
)
require (
github.com/gin-gonic/gin v1.8.0 // indirect
github.com/go-playground/validator/v10 v10.11.0 // indirect
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect
golang.org/x/net v0.0.0-20220531201128-c960675eff93 // indirect
gorm.io/driver/mysql v1.3.4 // indirect
)

4
go.sum
View File

@@ -4,8 +4,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/ctfang/command v1.0.0 h1:CurDflG0o7GntcRWWzbFlKdR3CMZUF/627Rv/Ii8HNE=
github.com/ctfang/command v1.0.0/go.mod h1:iWJcUCwZReswQ7T5IaRE6ZvGzZ/m9v53Z1na20JttV8=
github.com/ctfang/command v1.0.1 h1:e3GA3c68h153zs/AtTslY1+HpXMzrGkUqA8R4Qj9dtg=
github.com/ctfang/command v1.0.1/go.mod h1:iWJcUCwZReswQ7T5IaRE6ZvGzZ/m9v53Z1na20JttV8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=