mirror of
https://github.com/zhufuyi/sponge.git
synced 2025-12-24 10:40:55 +08:00
feat: support custom table primary key type and name
This commit is contained in:
@@ -8,18 +8,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PGField postgresql field
|
||||
type PGField struct {
|
||||
Name string `gorm:"column:name;" json:"name"`
|
||||
Type string `gorm:"column:type;" json:"type"`
|
||||
Comment string `gorm:"column:comment;" json:"comment"`
|
||||
Length int `gorm:"column:length;" json:"length"`
|
||||
Lengthvar int `gorm:"column:lengthvar;" json:"lengthvar"`
|
||||
Notnull bool `gorm:"column:notnull;" json:"notnull"`
|
||||
}
|
||||
|
||||
// GetPostgresqlTableInfo get table info from postgres
|
||||
func GetPostgresqlTableInfo(dsn string, tableName string) ([]*PGField, error) {
|
||||
func GetPostgresqlTableInfo(dsn string, tableName string) (PGFields, error) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetPostgresqlTableInfo error: %v", err)
|
||||
@@ -29,51 +19,47 @@ func GetPostgresqlTableInfo(dsn string, tableName string) ([]*PGField, error) {
|
||||
return getPostgresqlTableFields(db, tableName)
|
||||
}
|
||||
|
||||
func getPostgresqlTableFields(db *gorm.DB, tableName string) ([]*PGField, error) {
|
||||
query := fmt.Sprintf(`SELECT a.attname AS name, t.typname AS type, a.attlen AS length, a.atttypmod AS lengthvar, a.attnotnull AS notnull, b.description AS comment
|
||||
FROM pg_class c, pg_attribute a
|
||||
LEFT JOIN pg_description b
|
||||
ON a.attrelid = b.objoid
|
||||
AND a.attnum = b.objsubid, pg_type t
|
||||
WHERE c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
AND a.attrelid = c.oid
|
||||
AND a.atttypid = t.oid
|
||||
ORDER BY a.attnum;`, tableName)
|
||||
|
||||
var fields []*PGField
|
||||
result := db.Raw(query).Scan(&fields)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get table fields: %v", result.Error)
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// ConvertToSQLByPgFields convert to mysql table ddl
|
||||
func ConvertToSQLByPgFields(tableName string, fields []*PGField) (string, map[string]string) {
|
||||
func ConvertToSQLByPgFields(tableName string, fields PGFields) (string, map[string]string) {
|
||||
fieldStr := ""
|
||||
pgTypeMap := make(map[string]string) // name:type
|
||||
if len(fields) == 0 {
|
||||
return "", pgTypeMap
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
pgTypeMap[field.Name] = getType(field)
|
||||
sqlType := toMysqlType(field)
|
||||
if field.Name == "id" {
|
||||
fieldStr += fmt.Sprintf(" %s bigint unsigned primary key,\n", field.Name)
|
||||
continue
|
||||
}
|
||||
sqlType := field.getMysqlType()
|
||||
notnullStr := "not null"
|
||||
if !field.Notnull {
|
||||
notnullStr = "null"
|
||||
}
|
||||
fieldStr += fmt.Sprintf(" `%s` %s %s comment '%s',\n", field.Name, sqlType, notnullStr, field.Comment)
|
||||
}
|
||||
fieldStr = strings.TrimSuffix(fieldStr, ",\n")
|
||||
|
||||
primaryField := fields.getPrimaryField()
|
||||
if primaryField != nil {
|
||||
fieldStr += fmt.Sprintf(" PRIMARY KEY (`%s`)\n", primaryField.Name)
|
||||
} else {
|
||||
fieldStr = strings.TrimSuffix(fieldStr, ",\n")
|
||||
}
|
||||
sqlStr := fmt.Sprintf("CREATE TABLE `%s` (\n%s\n);", tableName, fieldStr)
|
||||
return sqlStr, pgTypeMap
|
||||
}
|
||||
|
||||
// PGField postgresql field
|
||||
type PGField struct {
|
||||
Name string `gorm:"column:name;" json:"name"`
|
||||
Type string `gorm:"column:type;" json:"type"`
|
||||
Comment string `gorm:"column:comment;" json:"comment"`
|
||||
Length int `gorm:"column:length;" json:"length"`
|
||||
Lengthvar int `gorm:"column:lengthvar;" json:"lengthvar"`
|
||||
Notnull bool `gorm:"column:notnull;" json:"notnull"`
|
||||
IsPrimaryKey bool `gorm:"column:is_primary_key;" json:"is_primary_key"`
|
||||
}
|
||||
|
||||
// nolint
|
||||
func toMysqlType(field *PGField) string {
|
||||
func (field *PGField) getMysqlType() string {
|
||||
switch field.Type {
|
||||
case "smallint", "integer", "smallserial", "serial", "int2", "int4":
|
||||
return "int"
|
||||
@@ -117,6 +103,79 @@ func toMysqlType(field *PGField) string {
|
||||
return field.Type
|
||||
}
|
||||
|
||||
type PGFields []*PGField
|
||||
|
||||
func (fields PGFields) getPrimaryField() *PGField {
|
||||
var f *PGField
|
||||
for _, field := range fields {
|
||||
if field.IsPrimaryKey || field.Name == "id" {
|
||||
f = field
|
||||
return f
|
||||
}
|
||||
}
|
||||
/*
|
||||
// if no primary key, find the first xxx_id field
|
||||
if f == nil {
|
||||
for _, field := range fields {
|
||||
if strings.HasSuffix(field.Name, "_id") {
|
||||
f = field
|
||||
f.IsPrimaryKey = true
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if no xxx_id field, find the first field
|
||||
if f == nil {
|
||||
for _, field := range fields {
|
||||
f = field
|
||||
f.IsPrimaryKey = true
|
||||
return f
|
||||
}
|
||||
}
|
||||
*/
|
||||
return f
|
||||
}
|
||||
|
||||
func getPostgresqlTableFields(db *gorm.DB, tableName string) (PGFields, error) {
|
||||
query := fmt.Sprintf(`SELECT
|
||||
a.attname AS name,
|
||||
t.typname AS type,
|
||||
a.attlen AS length,
|
||||
a.atttypmod AS lengthvar,
|
||||
a.attnotnull AS notnull,
|
||||
b.description AS comment,
|
||||
CASE
|
||||
WHEN pk.constraint_type = 'PRIMARY KEY' THEN true
|
||||
ELSE false
|
||||
END AS is_primary_key
|
||||
FROM pg_class c
|
||||
JOIN pg_attribute a ON a.attrelid = c.oid
|
||||
LEFT JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid
|
||||
JOIN pg_type t ON a.atttypid = t.oid
|
||||
LEFT JOIN (
|
||||
SELECT
|
||||
kcu.column_name,
|
||||
con.constraint_type
|
||||
FROM information_schema.table_constraints con
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON con.constraint_name = kcu.constraint_name
|
||||
WHERE con.constraint_type = 'PRIMARY KEY'
|
||||
AND con.table_name = '%s'
|
||||
) AS pk ON a.attname = pk.column_name
|
||||
WHERE c.relname = '%s'
|
||||
AND a.attnum > 0
|
||||
ORDER BY a.attnum;`, tableName, tableName)
|
||||
|
||||
var fields PGFields
|
||||
result := db.Raw(query).Scan(&fields)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get table fields: %v", result.Error)
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func getType(field *PGField) string {
|
||||
switch field.Type {
|
||||
case "character", "character varying", "varchar", "char", "bpchar":
|
||||
|
||||
Reference in New Issue
Block a user