From 284b563f81c8fc6957feddda68e42f73dc51e3ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 09:33:26 +0800 Subject: [PATCH] Initalize --- License | 21 ++++++ README.md | 15 ++++ go.mod | 5 ++ migrator.go | 211 ++++++++++++++++++++++++++++++++++++++++++++++++++++ sqlite.go | 80 ++++++++++++++++++++ 5 files changed, 332 insertions(+) create mode 100644 License create mode 100644 README.md create mode 100644 go.mod create mode 100644 migrator.go create mode 100644 sqlite.go diff --git a/License b/License new file mode 100644 index 0000000..037e165 --- /dev/null +++ b/License @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c02a9f7 --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +# GORM Sqlite Driver + +## USAGE + +```go +import ( + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// github.com/mattn/go-sqlite3 +db, err := gorm.Open(mysql.Open("gorm.db"), &gorm.Config{}) +``` + +Checkout [https://gorm.io](https://gorm.io) for details. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d7296d9 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module gorm.io/driver/sqlite + +go 1.14 + +require github.com/mattn/go-sqlite3 v2.0.3+incompatible diff --git a/migrator.go b/migrator.go new file mode 100644 index 0000000..14c682c --- /dev/null +++ b/migrator.go @@ -0,0 +1,211 @@ +package sqlite + +import ( + "fmt" + "regexp" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) HasColumn(value interface{}, name string) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", + ).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) AlterColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "?") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + + createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) + return m.DB.Exec(createSQL, m.FullDataTypeOf(field)).Error + } else { + return err + } + } else { + return fmt.Errorf("failed to alter field with name %v", name) + } + }) +} + +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + if columnType.Name() != name { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + } + + createSQL = fmt.Sprintf("PRAGMA foreign_keys=off;BEGIN TRANSACTION;"+createSQL+";INSERT INTO `%v`(%v) SELECT %v FROM `%v`;DROP TABLE `%v`;ALTER TABLE `%v` RENAME TO `%v`;COMMIT;", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table, stmt.Table, newTableName, stmt.Table) + + return m.DB.Exec(createSQL).Error + } else { + return err + } + }) +} + +func (m Migrator) CreateConstraint(interface{}, string) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) DropConstraint(interface{}, string) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) CurrentDatabase() (name string) { + var null interface{} + m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, + ).Row().Scan(&count) + return nil + }) + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + var sql string + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) + if sql != "" { + return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error + } + return fmt.Errorf("failed to find index with name %v", oldName) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error + }) +} diff --git a/sqlite.go b/sqlite.go new file mode 100644 index 0000000..238ad7f --- /dev/null +++ b/sqlite.go @@ -0,0 +1,80 @@ +package sqlite + +import ( + "database/sql" + + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" + _ "github.com/mattn/go-sqlite3" +) + +type Dialector struct { + DSN string +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Name() string { + return "sqlite" +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + LastInsertIDReversed: true, + }) + db.ConnPool, err = sql.Open("sqlite3", dialector.DSN) + return +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, + }}} +} + +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') +} + +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') +} + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "numeric" + case schema.Int, schema.Uint: + if field.AutoIncrement { + // https://www.sqlite.org/autoinc.html + return "integer PRIMARY KEY AUTOINCREMENT" + } else { + return "integer" + } + case schema.Float: + return "real" + case schema.String: + return "text" + case schema.Time: + return "datetime" + case schema.Bytes: + return "blob" + } + + return "" +}