diff --git a/ddlmod.go b/ddlmod.go index 87886d2..e6dcbc3 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -16,12 +16,23 @@ var ( indexRegexp = regexp.MustCompile(fmt.Sprintf("(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d-]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator)) tableRegexp = regexp.MustCompile(fmt.Sprintf("(?is)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator)) separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator)) - columnsRegexp = regexp.MustCompile(fmt.Sprintf("\\([%v]?([\\w\\d]+)[%v]?(?:,[%v]?([\\w\\d]+)[%v]){0,}\\)", sqliteSeparator, sqliteSeparator, sqliteSeparator, sqliteSeparator)) + columnsRegexp = regexp.MustCompile(fmt.Sprintf("[(,][%v]?(\\w+)[%v]?", sqliteSeparator, sqliteSeparator)) columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator)) defaultValueRegexp = regexp.MustCompile("(?i) DEFAULT \\(?(.+)?\\)?( |COLLATE|GENERATED|$)") regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) ) +func getAllColumns(s string) []string { + allMatches := columnsRegexp.FindAllStringSubmatch(s, -1) + columns := make([]string, 0, len(allMatches)) + for _, matches := range allMatches { + if len(matches) > 1 { + columns = append(columns, matches[1]) + } + } + return columns +} + type ddl struct { head string fields []string @@ -98,15 +109,12 @@ func parseDDL(strs ...string) (*ddl, error) { } if strings.HasPrefix(fUpper, "PRIMARY KEY") { - matches := columnsRegexp.FindStringSubmatch(f) - if len(matches) > 1 { - for _, name := range matches[1:] { - for idx, column := range result.columns { - if column.NameValue.String == name { - column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} - result.columns[idx] = column - break - } + for _, name := range getAllColumns(f) { + for idx, column := range result.columns { + if column.NameValue.String == name { + column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + result.columns[idx] = column + break } } } @@ -151,9 +159,9 @@ func parseDDL(strs ...string) (*ddl, error) { } } } else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 { - if columns := columnsRegexp.FindStringSubmatch(matches[1]); len(columns) == 1 { + for _, column := range getAllColumns(matches[1]) { for idx, c := range result.columns { - if c.NameValue.String == columns[0] { + if c.NameValue.String == column { c.UniqueValue = sql.NullBool{Bool: true, Valid: true} result.columns[idx] = c } diff --git a/ddlmod_test.go b/ddlmod_test.go index edc1c47..6ec7db7 100644 --- a/ddlmod_test.go +++ b/ddlmod_test.go @@ -20,7 +20,7 @@ func TestParseDDL(t *testing.T) { "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", }, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, @@ -56,11 +56,29 @@ func TestParseDDL(t *testing.T) { ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, - UniqueValue: sql.NullBool{Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}, }, }, }, + { + "unique index", + []string{ + "CREATE TABLE `test-b` (`field` integer NOT NULL)", + "CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0", + }, + 1, + []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "field", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + }, + }, + }, } for _, p := range params {