mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-10-05 16:57:07 +08:00
feat: implement sponge commands
This commit is contained in:
151
pkg/sql2code/sql2code.go
Normal file
151
pkg/sql2code/sql2code.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package sql2code
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/zhufuyi/sponge/pkg/sql2code/parser"
|
||||
)
|
||||
|
||||
// Args 参数
|
||||
type Args struct {
|
||||
SQL string // DDL sql
|
||||
|
||||
DDLFile string // 读取文件的DDL sql
|
||||
|
||||
DBDsn string // 从db获取表的DDL sql
|
||||
DBTable string
|
||||
|
||||
Package string // 生成字段的包名(只有model类型有效)
|
||||
GormType bool // 是否显示gorm type名称(只有model类型代码有效)
|
||||
JSONTag bool // 是否包括json tag
|
||||
JSONNamedType int // json命名类型,0:和列名一致,其他值表示驼峰
|
||||
IsEmbed bool // 是否嵌入gorm.Model
|
||||
CodeType string // 指定生成代码用途,支持4中类型,分别是 model(默认), json, dao, handler
|
||||
ForceTableName bool
|
||||
Charset string
|
||||
Collation string
|
||||
TablePrefix string
|
||||
ColumnPrefix string
|
||||
NoNullType bool
|
||||
NullStyle string
|
||||
}
|
||||
|
||||
func (a *Args) checkValid() error {
|
||||
if a.SQL == "" && a.DDLFile == "" && (a.DBDsn == "" && a.DBTable == "") {
|
||||
return errors.New("you must specify sql or ddl file")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSQL(args *Args) (string, error) {
|
||||
if args.SQL != "" {
|
||||
return args.SQL, nil
|
||||
}
|
||||
|
||||
sql := ""
|
||||
if args.DDLFile != "" {
|
||||
b, err := os.ReadFile(args.DDLFile)
|
||||
if err != nil {
|
||||
return sql, fmt.Errorf("read %s failed, %s", args.DDLFile, err)
|
||||
}
|
||||
return string(b), nil
|
||||
} else if args.DBDsn != "" {
|
||||
if args.DBTable == "" {
|
||||
return sql, errors.New("miss mysql table")
|
||||
}
|
||||
sqlStr, err := parser.GetTableInfo(args.DBDsn, args.DBTable)
|
||||
if err != nil {
|
||||
return sql, err
|
||||
}
|
||||
return sqlStr, nil
|
||||
}
|
||||
|
||||
return sql, errors.New("no SQL input(-sql|-f|-db-dsn)")
|
||||
}
|
||||
|
||||
func getOptions(args *Args) []parser.Option {
|
||||
var opts []parser.Option
|
||||
|
||||
if args.Charset != "" {
|
||||
opts = append(opts, parser.WithCharset(args.Charset))
|
||||
}
|
||||
if args.Collation != "" {
|
||||
opts = append(opts, parser.WithCollation(args.Collation))
|
||||
}
|
||||
if args.JSONTag {
|
||||
opts = append(opts, parser.WithJSONTag(args.JSONNamedType))
|
||||
}
|
||||
if args.TablePrefix != "" {
|
||||
opts = append(opts, parser.WithTablePrefix(args.TablePrefix))
|
||||
}
|
||||
if args.ColumnPrefix != "" {
|
||||
opts = append(opts, parser.WithColumnPrefix(args.ColumnPrefix))
|
||||
}
|
||||
if args.NoNullType {
|
||||
opts = append(opts, parser.WithNoNullType())
|
||||
}
|
||||
if args.IsEmbed {
|
||||
opts = append(opts, parser.WithEmbed())
|
||||
}
|
||||
|
||||
if args.NullStyle != "" {
|
||||
switch args.NullStyle {
|
||||
case "sql":
|
||||
opts = append(opts, parser.WithNullStyle(parser.NullInSql))
|
||||
case "ptr":
|
||||
opts = append(opts, parser.WithNullStyle(parser.NullInPointer))
|
||||
default:
|
||||
fmt.Printf("invalid null style: %s\n", args.NullStyle)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
opts = append(opts, parser.WithNullStyle(parser.NullDisable))
|
||||
}
|
||||
if args.Package != "" {
|
||||
opts = append(opts, parser.WithPackage(args.Package))
|
||||
}
|
||||
if args.GormType {
|
||||
opts = append(opts, parser.WithGormType())
|
||||
}
|
||||
if args.ForceTableName {
|
||||
opts = append(opts, parser.WithForceTableName())
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
// GenerateOne 根据sql生成gorm代码,sql可以从参数、文件、db三种方式获取,优先从高到低
|
||||
func GenerateOne(args *Args) (string, error) {
|
||||
codes, err := Generate(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if args.CodeType == "" {
|
||||
args.CodeType = parser.CodeTypeModel // 默认为model code
|
||||
}
|
||||
out, ok := codes[args.CodeType]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown code type %s", args.CodeType)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Generate 生成model, json, dao, handler不同用途代码
|
||||
func Generate(args *Args) (map[string]string, error) {
|
||||
if err := args.checkValid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sql, err := getSQL(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opt := getOptions(args)
|
||||
|
||||
return parser.ParseSQL(sql, opt...)
|
||||
}
|
Reference in New Issue
Block a user