feat: support custom table primary key type and name

This commit is contained in:
zhuyasen
2024-11-04 20:30:28 +08:00
parent 4deb203a02
commit 6eb9e51b9d
51 changed files with 5514 additions and 344 deletions

View File

@@ -4,12 +4,100 @@ import (
"fmt"
"testing"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql"
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types"
"github.com/jinzhu/inflection"
"github.com/stretchr/testify/assert"
"github.com/zhufuyi/sqlparser/dependency/mysql"
"github.com/zhufuyi/sqlparser/dependency/types"
)
func TestParseSql(t *testing.T) {
func TestParseSQL(t *testing.T) {
sqls := []string{`create table user
(
id bigint unsigned auto_increment
primary key,
created_at datetime null,
updated_at datetime null,
deleted_at datetime null,
name char(50) not null comment '用户名',
password char(100) not null comment '密码',
email char(50) not null comment '邮件',
phone bigint unsigned not null comment '手机号码',
age tinyint not null comment '年龄',
gender tinyint not null comment '性别1:男2:女3:未知',
status tinyint not null comment '账号状态1:未激活2:已激活3:封禁',
login_state tinyint not null comment '登录状态1:未登录2:已登录',
constraint user_email_uindex
unique (email)
);`,
`create table user_order
(
id varchar(36) not null comment '订单id'
primary key,
product_id varchar(36) not null comment '商品id',
user_id bigint unsigned not null comment '用户id',
status smallint null comment '0:未支付, 1:已支付, 2:已取消',
created_at timestamp null comment '创建时间',
updated_at timestamp null comment '更新时间'
);`,
`create table user_str
(
user_id varchar(36) not null comment '用户id'
primary key,
username varchar(50) not null comment '用户名',
email varchar(100) not null comment '邮箱',
created_at datetime null comment '创建时间',
constraint email
unique (email)
);`,
`create table user_no_primary
(
username varchar(50) not null comment '用户名',
email varchar(100) not null comment '邮箱',
user_id varchar(36) not null comment '用户id',
created_at datetime null comment '创建时间',
constraint email
unique (email)
);`}
for _, sql := range sqls {
codes, err := ParseSQL(sql, WithJSONTag(0), WithEmbed())
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
codes, err = ParseSQL(sql, WithJSONTag(1), WithWebProto(), WithDBDriver(DBDriverMysql))
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
codes, err = ParseSQL(sql, WithJSONTag(0), WithDBDriver(DBDriverPostgresql))
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
codes, err = ParseSQL(sql, WithJSONTag(0), WithDBDriver(DBDriverSqlite))
assert.Nil(t, err)
for k, v := range codes {
assert.NotEmpty(t, k)
assert.NotEmpty(t, v)
}
//printCode(codes)
}
}
func TestParseSqlWithTablePrefix(t *testing.T) {
sql := `CREATE TABLE t_person_info (
id BIGINT(11) PRIMARY KEY AUTO_INCREMENT NOT NULL COMMENT 'id',
age INT(11) unsigned NULL,
@@ -177,7 +265,7 @@ func Test_goTypeToProto(t *testing.T) {
{GoType: "uint"},
{GoType: "time.Time"},
}
v := goTypeToProto(fields, 1)
v := goTypeToProto(fields, 1, false)
assert.NotNil(t, v)
}
@@ -206,14 +294,14 @@ func Test_initTemplate(t *testing.T) {
}
func TestGetMysqlTableInfo(t *testing.T) {
info, err := GetMysqlTableInfo("root:123456@(192.168.3.37:3306)/test", "user")
info, err := GetMysqlTableInfo("root:123456@(192.168.3.37:3306)/account", "user_order")
t.Log(err, info)
}
func TestGetPostgresqlTableInfo(t *testing.T) {
var (
dbname = "account"
tableName = "user_example"
tableName = "user_order"
dsn = fmt.Sprintf("host=192.168.3.37 port=5432 user=root password=123456 dbname=%s sslmode=disable", dbname)
)
@@ -228,8 +316,13 @@ func TestGetPostgresqlTableInfo(t *testing.T) {
t.Log(fieldTypes)
}
func Test_getPostgresqlTableFields(t *testing.T) {
defer func() { _ = recover() }()
_, _ = getPostgresqlTableFields(nil, "foobar")
}
func TestGetSqliteTableInfo(t *testing.T) {
info, err := GetSqliteTableInfo("..\\..\\..\\test\\sql\\sqlite\\sponge.db", "user_example")
info, err := GetSqliteTableInfo("..\\..\\..\\test\\sql\\sqlite\\sponge.db", "user_order")
t.Log(err, info)
}
@@ -260,7 +353,7 @@ func TestConvertToSQLByPgFields(t *testing.T) {
t.Log(sql, tps)
}
func Test_toMysqlTable(t *testing.T) {
func Test_PGField_getMysqlType(t *testing.T) {
fields := []*PGField{
{Type: "smallint"},
{Type: "bigint"},
@@ -277,7 +370,23 @@ func Test_toMysqlTable(t *testing.T) {
{Type: "boolean"},
}
for _, field := range fields {
t.Log(toMysqlType(field), getType(field))
t.Log(field.getMysqlType(), getType(field))
}
}
func Test_SqliteField_getMysqlType(t *testing.T) {
fields := []*SqliteField{
{Type: "integer"},
{Type: "text"},
{Type: "real"},
{Type: "numeric"},
{Type: "blob"},
{Type: "datetime"},
{Type: "boolean"},
{Type: "unknown_type"},
}
for _, field := range fields {
t.Log(field.getMysqlType())
}
}
@@ -288,9 +397,9 @@ func printCode(code map[string]string) {
}
func printPGFields(fields []*PGField) {
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v\n", "Name", "Type", "Length", "Lengthvar", "Notnull", "Comment")
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v %-20v\n", "Name", "Type", "Length", "Lengthvar", "Notnull", "Comment", "IsPrimaryKey")
for _, p := range fields {
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v\n", p.Name, p.Type, p.Length, p.Lengthvar, p.Notnull, p.Comment)
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v %-20v\n", p.Name, p.Type, p.Length, p.Lengthvar, p.Notnull, p.Comment, p.IsPrimaryKey)
}
}
@@ -442,3 +551,82 @@ func Test_embedTimeFields(t *testing.T) {
fields = embedTimeField(names, []*MgoField{})
t.Log(fields)
}
func TestCrudInfo(t *testing.T) {
data := tmplData{
TableName: "User",
TName: "user",
NameFunc: false,
RawTableName: "user",
Fields: []tmplField{
{
ColName: "name",
Name: "Name",
GoType: "string",
Tag: "json:\"name\"",
Comment: "姓名",
JSONName: "name",
DBDriver: "mysql",
},
{
ColName: "age",
Name: "Age",
GoType: "int",
Tag: "json:\"age\"",
Comment: "年龄",
JSONName: "age",
DBDriver: "mysql",
},
{
ColName: "created_at",
Name: "CreatedAt",
GoType: "time.Time",
Tag: "json:\"created_at\"",
Comment: "创建时间",
JSONName: "createdAt",
DBDriver: "mysql",
},
},
Comment: "用户信息",
SubStructs: "",
ProtoSubStructs: "",
DBDriver: "mysql",
}
info := newCrudInfo(data)
isPrimary := info.isIDPrimaryKey()
assert.Equal(t, false, isPrimary)
code := info.getCode()
assert.Contains(t, code, `"tableNameCamel":"User","tableNameCamelFCL":"user"`)
grpcValidation := info.GetGRPCProtoValidation()
assert.Contains(t, grpcValidation, "validate.rules")
webValidation := info.GetWebProtoValidation()
assert.Contains(t, webValidation, "validate.rules")
info = nil
_ = info.isIDPrimaryKey()
_ = info.getCode()
_ = info.GetGRPCProtoValidation()
_ = info.GetWebProtoValidation()
}
func Test_customEndOfLetterToLower(t *testing.T) {
names := []string{
"ID",
"IP",
"userID",
"orderID",
"LocalIP",
"bus",
"BUS",
"x",
"s",
}
for _, name := range names {
t.Log(customEndOfLetterToLower(name, inflection.Plural(name)))
}
}