Files
sponge/pkg/sql2code/sql2code.go
2022-10-17 23:11:21 +08:00

152 lines
3.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package sql2code
import (
"errors"
"fmt"
"os"
"github.com/zhufuyi/sponge/pkg/sql2code/parser"
)
// Args holds the inputs and settings for SQL-to-code generation.
type Args struct {
	// SQL is DDL text supplied directly; when non-empty it takes precedence
	// over DDLFile and the database source.
	SQL string
	// DDLFile is the path of a file to read the DDL SQL from.
	DDLFile string
	// DBDsn is the data source name of the database to fetch the table DDL
	// from, used together with DBTable. NOTE(review): presumably a MySQL DSN
	// (the missing-table error mentions mysql) — confirm.
	DBDsn string
	// DBTable is the name of the table whose DDL is fetched through DBDsn.
	DBTable string
	// Package is the package name used for the generated code (only
	// effective for the model code type).
	Package string
	// GormType controls whether the gorm type name is emitted (only
	// effective for model code).
	GormType bool
	// JSONTag controls whether json tags are included.
	JSONTag bool
	// JSONNamedType selects the json naming style: 0 keeps the column name,
	// any other value means camelCase. Only used when JSONTag is true.
	JSONNamedType int
	// IsEmbed controls whether gorm.Model is embedded.
	IsEmbed bool
	// CodeType selects which generated code GenerateOne returns. Four types
	// are supported: model (the default when empty), json, dao and handler.
	CodeType string
	// ForceTableName, when true, enables the parser's WithForceTableName option.
	ForceTableName bool
	// Charset and Collation, when non-empty, are passed to the parser via
	// WithCharset and WithCollation.
	Charset   string
	Collation string
	// TablePrefix and ColumnPrefix, when non-empty, are passed to the parser
	// via WithTablePrefix and WithColumnPrefix. NOTE(review): presumably
	// prefixes to strip from table/column names — confirm in parser.
	TablePrefix  string
	ColumnPrefix string
	// NoNullType, when true, enables the parser's WithNoNullType option.
	NoNullType bool
	// NullStyle selects how nullable columns are represented: "" maps to
	// NullDisable, "sql" to NullInSql and "ptr" to NullInPointer; any other
	// value is invalid.
	NullStyle string
}
// checkValid reports whether args describe a usable generation request.
//
// At least one SQL source must be set: SQL, DDLFile, DBDsn or DBTable.
// Incomplete database pairs (a DSN without a table, or a table without a
// DSN) are not rejected here; getSQL reports those.
//
// NullStyle, if set, must be "sql" or "ptr". It is validated here because
// getOptions, given an unknown value, only prints a message and returns no
// options at all. That would otherwise silently generate code with every
// setting (charset, json tags, prefixes, …) dropped.
func (a *Args) checkValid() error {
	if a.SQL == "" && a.DDLFile == "" && a.DBDsn == "" && a.DBTable == "" {
		return errors.New("you must specify sql or ddl file")
	}

	switch a.NullStyle {
	case "", "sql", "ptr":
		// valid
	default:
		return fmt.Errorf("invalid null style: %s", a.NullStyle)
	}
	return nil
}
// getSQL resolves the DDL text to parse. Sources are tried in priority order:
// the literal args.SQL, then the file named by args.DDLFile, then the
// definition of args.DBTable fetched through args.DBDsn.
//
// It returns an error if the DDL file cannot be read, if a DSN is given
// without a table name, if fetching the table definition fails, or if no
// source is specified at all. On error the returned string is empty.
func getSQL(args *Args) (string, error) {
	switch {
	case args.SQL != "":
		return args.SQL, nil

	case args.DDLFile != "":
		b, err := os.ReadFile(args.DDLFile)
		if err != nil {
			// %w (rather than %s) keeps the underlying *fs.PathError reachable
			// through errors.Is/errors.As; the rendered message is unchanged.
			return "", fmt.Errorf("read %s failed, %w", args.DDLFile, err)
		}
		return string(b), nil

	case args.DBDsn != "":
		if args.DBTable == "" {
			return "", errors.New("miss mysql table")
		}
		ddl, err := parser.GetTableInfo(args.DBDsn, args.DBTable)
		if err != nil {
			return "", err
		}
		return ddl, nil
	}

	return "", errors.New("no SQL input(-sql|-f|-db-dsn)")
}
// getOptions converts the generation settings in args into parser options.
//
// Options are appended in a fixed order. A null-style option is always
// included: NullDisable when NullStyle is empty, NullInSql for "sql" and
// NullInPointer for "ptr". For any other NullStyle a message is printed to
// stdout and nil is returned, so no options are applied at all.
func getOptions(args *Args) []parser.Option {
	// Resolve the null style first so an invalid value bails out before any
	// other option is built.
	var nullStyle parser.Option
	switch args.NullStyle {
	case "":
		nullStyle = parser.WithNullStyle(parser.NullDisable)
	case "sql":
		nullStyle = parser.WithNullStyle(parser.NullInSql)
	case "ptr":
		nullStyle = parser.WithNullStyle(parser.NullInPointer)
	default:
		fmt.Printf("invalid null style: %s\n", args.NullStyle)
		return nil
	}

	options := make([]parser.Option, 0, 11)

	if args.Charset != "" {
		options = append(options, parser.WithCharset(args.Charset))
	}
	if args.Collation != "" {
		options = append(options, parser.WithCollation(args.Collation))
	}
	if args.JSONTag {
		options = append(options, parser.WithJSONTag(args.JSONNamedType))
	}
	if args.TablePrefix != "" {
		options = append(options, parser.WithTablePrefix(args.TablePrefix))
	}
	if args.ColumnPrefix != "" {
		options = append(options, parser.WithColumnPrefix(args.ColumnPrefix))
	}
	if args.NoNullType {
		options = append(options, parser.WithNoNullType())
	}
	if args.IsEmbed {
		options = append(options, parser.WithEmbed())
	}

	// The null-style option keeps its original position in the sequence.
	options = append(options, nullStyle)

	if args.Package != "" {
		options = append(options, parser.WithPackage(args.Package))
	}
	if args.GormType {
		options = append(options, parser.WithGormType())
	}
	if args.ForceTableName {
		options = append(options, parser.WithForceTableName())
	}

	return options
}
// GenerateOne generates gorm-related code from SQL and returns the output for
// a single code type. The SQL is taken from args.SQL, args.DDLFile or the
// database (args.DBDsn + args.DBTable), in that order of priority.
//
// args.CodeType selects which output is returned; when it is empty,
// parser.CodeTypeModel is used. args itself is not modified.
//
// It returns an error if generation fails or if the requested code type is
// not among the generated outputs.
func GenerateOne(args *Args) (string, error) {
	codes, err := Generate(args)
	if err != nil {
		return "", err
	}

	// Apply the default locally instead of writing it back into the caller's
	// Args, which would be a hidden side effect on a shared value.
	codeType := args.CodeType
	if codeType == "" {
		codeType = parser.CodeTypeModel // default to model code
	}

	out, ok := codes[codeType]
	if !ok {
		return "", fmt.Errorf("unknown code type %s", codeType)
	}
	return out, nil
}
// Generate generates code for the different purposes (model, json, dao,
// handler) from the SQL described by args. It returns the outputs keyed by
// code type, as produced by parser.ParseSQL.
//
// args is validated first. The SQL is then resolved via getSQL, and the
// parser options are built via getOptions.
func Generate(args *Args) (map[string]string, error) {
	err := args.checkValid()
	if err != nil {
		return nil, err
	}

	ddl, err := getSQL(args)
	if err != nil {
		return nil, err
	}

	return parser.ParseSQL(ddl, getOptions(args)...)
}