mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-10-06 01:07:08 +08:00
feat: support custom table primary key type and name
This commit is contained in:
@@ -4,12 +4,100 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql"
|
||||
"github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types"
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zhufuyi/sqlparser/dependency/mysql"
|
||||
"github.com/zhufuyi/sqlparser/dependency/types"
|
||||
)
|
||||
|
||||
func TestParseSql(t *testing.T) {
|
||||
func TestParseSQL(t *testing.T) {
|
||||
sqls := []string{`create table user
|
||||
(
|
||||
id bigint unsigned auto_increment
|
||||
primary key,
|
||||
created_at datetime null,
|
||||
updated_at datetime null,
|
||||
deleted_at datetime null,
|
||||
name char(50) not null comment '用户名',
|
||||
password char(100) not null comment '密码',
|
||||
email char(50) not null comment '邮件',
|
||||
phone bigint unsigned not null comment '手机号码',
|
||||
age tinyint not null comment '年龄',
|
||||
gender tinyint not null comment '性别,1:男,2:女,3:未知',
|
||||
status tinyint not null comment '账号状态,1:未激活,2:已激活,3:封禁',
|
||||
login_state tinyint not null comment '登录状态,1:未登录,2:已登录',
|
||||
constraint user_email_uindex
|
||||
unique (email)
|
||||
);`,
|
||||
|
||||
`create table user_order
|
||||
(
|
||||
id varchar(36) not null comment '订单id'
|
||||
primary key,
|
||||
product_id varchar(36) not null comment '商品id',
|
||||
user_id bigint unsigned not null comment '用户id',
|
||||
status smallint null comment '0:未支付, 1:已支付, 2:已取消',
|
||||
created_at timestamp null comment '创建时间',
|
||||
updated_at timestamp null comment '更新时间'
|
||||
);`,
|
||||
|
||||
`create table user_str
|
||||
(
|
||||
user_id varchar(36) not null comment '用户id'
|
||||
primary key,
|
||||
username varchar(50) not null comment '用户名',
|
||||
email varchar(100) not null comment '邮箱',
|
||||
created_at datetime null comment '创建时间',
|
||||
constraint email
|
||||
unique (email)
|
||||
);`,
|
||||
|
||||
`create table user_no_primary
|
||||
(
|
||||
username varchar(50) not null comment '用户名',
|
||||
email varchar(100) not null comment '邮箱',
|
||||
user_id varchar(36) not null comment '用户id',
|
||||
created_at datetime null comment '创建时间',
|
||||
constraint email
|
||||
unique (email)
|
||||
);`}
|
||||
|
||||
for _, sql := range sqls {
|
||||
codes, err := ParseSQL(sql, WithJSONTag(0), WithEmbed())
|
||||
assert.Nil(t, err)
|
||||
for k, v := range codes {
|
||||
assert.NotEmpty(t, k)
|
||||
assert.NotEmpty(t, v)
|
||||
}
|
||||
//printCode(codes)
|
||||
|
||||
codes, err = ParseSQL(sql, WithJSONTag(1), WithWebProto(), WithDBDriver(DBDriverMysql))
|
||||
assert.Nil(t, err)
|
||||
for k, v := range codes {
|
||||
assert.NotEmpty(t, k)
|
||||
assert.NotEmpty(t, v)
|
||||
}
|
||||
//printCode(codes)
|
||||
|
||||
codes, err = ParseSQL(sql, WithJSONTag(0), WithDBDriver(DBDriverPostgresql))
|
||||
assert.Nil(t, err)
|
||||
for k, v := range codes {
|
||||
assert.NotEmpty(t, k)
|
||||
assert.NotEmpty(t, v)
|
||||
}
|
||||
//printCode(codes)
|
||||
|
||||
codes, err = ParseSQL(sql, WithJSONTag(0), WithDBDriver(DBDriverSqlite))
|
||||
assert.Nil(t, err)
|
||||
for k, v := range codes {
|
||||
assert.NotEmpty(t, k)
|
||||
assert.NotEmpty(t, v)
|
||||
}
|
||||
//printCode(codes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSqlWithTablePrefix(t *testing.T) {
|
||||
sql := `CREATE TABLE t_person_info (
|
||||
id BIGINT(11) PRIMARY KEY AUTO_INCREMENT NOT NULL COMMENT 'id',
|
||||
age INT(11) unsigned NULL,
|
||||
@@ -177,7 +265,7 @@ func Test_goTypeToProto(t *testing.T) {
|
||||
{GoType: "uint"},
|
||||
{GoType: "time.Time"},
|
||||
}
|
||||
v := goTypeToProto(fields, 1)
|
||||
v := goTypeToProto(fields, 1, false)
|
||||
assert.NotNil(t, v)
|
||||
}
|
||||
|
||||
@@ -206,14 +294,14 @@ func Test_initTemplate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetMysqlTableInfo(t *testing.T) {
|
||||
info, err := GetMysqlTableInfo("root:123456@(192.168.3.37:3306)/test", "user")
|
||||
info, err := GetMysqlTableInfo("root:123456@(192.168.3.37:3306)/account", "user_order")
|
||||
t.Log(err, info)
|
||||
}
|
||||
|
||||
func TestGetPostgresqlTableInfo(t *testing.T) {
|
||||
var (
|
||||
dbname = "account"
|
||||
tableName = "user_example"
|
||||
tableName = "user_order"
|
||||
dsn = fmt.Sprintf("host=192.168.3.37 port=5432 user=root password=123456 dbname=%s sslmode=disable", dbname)
|
||||
)
|
||||
|
||||
@@ -228,8 +316,13 @@ func TestGetPostgresqlTableInfo(t *testing.T) {
|
||||
t.Log(fieldTypes)
|
||||
}
|
||||
|
||||
func Test_getPostgresqlTableFields(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
_, _ = getPostgresqlTableFields(nil, "foobar")
|
||||
}
|
||||
|
||||
func TestGetSqliteTableInfo(t *testing.T) {
|
||||
info, err := GetSqliteTableInfo("..\\..\\..\\test\\sql\\sqlite\\sponge.db", "user_example")
|
||||
info, err := GetSqliteTableInfo("..\\..\\..\\test\\sql\\sqlite\\sponge.db", "user_order")
|
||||
t.Log(err, info)
|
||||
}
|
||||
|
||||
@@ -260,7 +353,7 @@ func TestConvertToSQLByPgFields(t *testing.T) {
|
||||
t.Log(sql, tps)
|
||||
}
|
||||
|
||||
func Test_toMysqlTable(t *testing.T) {
|
||||
func Test_PGField_getMysqlType(t *testing.T) {
|
||||
fields := []*PGField{
|
||||
{Type: "smallint"},
|
||||
{Type: "bigint"},
|
||||
@@ -277,7 +370,23 @@ func Test_toMysqlTable(t *testing.T) {
|
||||
{Type: "boolean"},
|
||||
}
|
||||
for _, field := range fields {
|
||||
t.Log(toMysqlType(field), getType(field))
|
||||
t.Log(field.getMysqlType(), getType(field))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SqliteField_getMysqlType(t *testing.T) {
|
||||
fields := []*SqliteField{
|
||||
{Type: "integer"},
|
||||
{Type: "text"},
|
||||
{Type: "real"},
|
||||
{Type: "numeric"},
|
||||
{Type: "blob"},
|
||||
{Type: "datetime"},
|
||||
{Type: "boolean"},
|
||||
{Type: "unknown_type"},
|
||||
}
|
||||
for _, field := range fields {
|
||||
t.Log(field.getMysqlType())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,9 +397,9 @@ func printCode(code map[string]string) {
|
||||
}
|
||||
|
||||
func printPGFields(fields []*PGField) {
|
||||
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v\n", "Name", "Type", "Length", "Lengthvar", "Notnull", "Comment")
|
||||
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v %-20v\n", "Name", "Type", "Length", "Lengthvar", "Notnull", "Comment", "IsPrimaryKey")
|
||||
for _, p := range fields {
|
||||
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v\n", p.Name, p.Type, p.Length, p.Lengthvar, p.Notnull, p.Comment)
|
||||
fmt.Printf("%-20v %-20v %-20v %-20v %-20v %-20v %-20v\n", p.Name, p.Type, p.Length, p.Lengthvar, p.Notnull, p.Comment, p.IsPrimaryKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,3 +551,82 @@ func Test_embedTimeFields(t *testing.T) {
|
||||
fields = embedTimeField(names, []*MgoField{})
|
||||
t.Log(fields)
|
||||
}
|
||||
|
||||
func TestCrudInfo(t *testing.T) {
|
||||
data := tmplData{
|
||||
TableName: "User",
|
||||
TName: "user",
|
||||
NameFunc: false,
|
||||
RawTableName: "user",
|
||||
Fields: []tmplField{
|
||||
{
|
||||
ColName: "name",
|
||||
Name: "Name",
|
||||
GoType: "string",
|
||||
Tag: "json:\"name\"",
|
||||
Comment: "姓名",
|
||||
JSONName: "name",
|
||||
DBDriver: "mysql",
|
||||
},
|
||||
{
|
||||
ColName: "age",
|
||||
Name: "Age",
|
||||
GoType: "int",
|
||||
Tag: "json:\"age\"",
|
||||
Comment: "年龄",
|
||||
JSONName: "age",
|
||||
DBDriver: "mysql",
|
||||
},
|
||||
{
|
||||
ColName: "created_at",
|
||||
Name: "CreatedAt",
|
||||
GoType: "time.Time",
|
||||
Tag: "json:\"created_at\"",
|
||||
Comment: "创建时间",
|
||||
JSONName: "createdAt",
|
||||
DBDriver: "mysql",
|
||||
},
|
||||
},
|
||||
Comment: "用户信息",
|
||||
SubStructs: "",
|
||||
ProtoSubStructs: "",
|
||||
DBDriver: "mysql",
|
||||
}
|
||||
|
||||
info := newCrudInfo(data)
|
||||
|
||||
isPrimary := info.isIDPrimaryKey()
|
||||
assert.Equal(t, false, isPrimary)
|
||||
|
||||
code := info.getCode()
|
||||
assert.Contains(t, code, `"tableNameCamel":"User","tableNameCamelFCL":"user"`)
|
||||
|
||||
grpcValidation := info.GetGRPCProtoValidation()
|
||||
assert.Contains(t, grpcValidation, "validate.rules")
|
||||
|
||||
webValidation := info.GetWebProtoValidation()
|
||||
assert.Contains(t, webValidation, "validate.rules")
|
||||
|
||||
info = nil
|
||||
_ = info.isIDPrimaryKey()
|
||||
_ = info.getCode()
|
||||
_ = info.GetGRPCProtoValidation()
|
||||
_ = info.GetWebProtoValidation()
|
||||
}
|
||||
|
||||
func Test_customEndOfLetterToLower(t *testing.T) {
|
||||
names := []string{
|
||||
"ID",
|
||||
"IP",
|
||||
"userID",
|
||||
"orderID",
|
||||
"LocalIP",
|
||||
"bus",
|
||||
"BUS",
|
||||
"x",
|
||||
"s",
|
||||
}
|
||||
for _, name := range names {
|
||||
t.Log(customEndOfLetterToLower(name, inflection.Plural(name)))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user