From 75cba1fa4865d4927398c1229c0e88d10b02f216 Mon Sep 17 00:00:00 2001 From: tangpanqing Date: Wed, 11 Jan 2023 23:33:20 +0800 Subject: [PATCH] update --- aorm.go | 54 +---- builder/crud.go | 43 ++-- builder/handle.go | 4 +- migrate_mssql/migrate.go | 40 +-- migrate_mysql/migrate.go | 44 +--- migrate_postgres/migrate.go | 69 +++--- migrate_sqlite3/migrate.go | 40 +-- migrator/migrator.go | 53 +--- model/AormDB.go | 52 ++++ model/AormTx.go | 42 ++++ model/model.go | 7 +- test/aorm_test.go | 470 ++++++++++++++++++------------------ 12 files changed, 438 insertions(+), 480 deletions(-) create mode 100644 model/AormDB.go create mode 100644 model/AormTx.go diff --git a/aorm.go b/aorm.go index 2ee446a..df8b5ce 100644 --- a/aorm.go +++ b/aorm.go @@ -7,52 +7,21 @@ import ( "github.com/tangpanqing/aorm/model" ) -// DbContent 数据库连接与数据库类型 -type DbContent struct { - DriverName string - DbLink *sql.DB -} - -func (dc *DbContent) Db() *sql.DB { - return dc.DbLink -} - -func (dc *DbContent) Begin() *sql.Tx { - tx, _ := dc.DbLink.Begin() - return tx -} - -func (dc *DbContent) Exec(query string, args ...interface{}) (sql.Result, error) { - return dc.Exec(query, args...) -} - -func (dc *DbContent) Prepare(query string) (*sql.Stmt, error) { - return dc.Prepare(query) -} - -func (dc *DbContent) Query(query string, args ...interface{}) (*sql.Rows, error) { - return dc.Query(query, args...) -} - -func (dc *DbContent) QueryRow(query string, args ...interface{}) *sql.Row { - return dc.QueryRow(query, args...) -} - //Open 开始一个数据库连接 -func Open(driverName string, dataSourceName string) (*DbContent, error) { - db, err := sql.Open(driverName, dataSourceName) +func Open(driverName string, dataSourceName string) (*model.AormDB, error) { + sqlDB, err := sql.Open(driverName, dataSourceName) if err != nil { - return &DbContent{}, err + return &model.AormDB{}, err } - err2 := db.Ping() + err2 := sqlDB.Ping() if err2 != nil { - return &DbContent{}, err2 + return &model.AormDB{}, err2 } - return &DbContent{ - DriverName: driverName, - DbLink: db, + return &model.AormDB{ + Driver: driverName, + SqlDB: sqlDB, }, nil } @@ -61,12 +30,11 @@ func Store(destList ...interface{}) { } // Db 开始一个数据库操作 -func Db(linkCommon ...model.LinkCommon) *builder.Builder { +func Db(linkCommon model.LinkCommon) *builder.Builder { b := &builder.Builder{} - if len(linkCommon) > 0 { - b.LinkCommon = linkCommon[0] - } + b.LinkCommon = linkCommon + b.Debug(linkCommon.GetDebugMode()) return b } diff --git a/builder/crud.go b/builder/crud.go index 9fd1a95..8bdc0ed 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -34,7 +34,6 @@ const RawEq = "RawEq" // Builder 查询记录所需要的条件 type Builder struct { LinkCommon model.LinkCommon - driverName string table interface{} tableAlias string @@ -63,12 +62,6 @@ func (b *Builder) Debug(isDebug bool) *Builder { return b } -// Driver 驱动类型 -func (b *Builder) Driver(driverName string) *Builder { - b.driverName = driverName - return b -} - // Distinct 过滤重复记录 func (b *Builder) Distinct(distinct bool) *Builder { b.distinct = distinct @@ -99,7 +92,7 @@ func (b *Builder) Insert(dest interface{}) (int64, error) { key, tagMap := getFieldNameByReflect(typeOf.Elem().Field(i)) //如果是Postgres数据库,寻找主键 - if b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Postgres { if _, ok := tagMap["primary"]; ok { primaryKey = key } @@ -115,15 +108,15 @@ func (b *Builder) Insert(dest interface{}) (int64, error) { } } - sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" + sql := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" - if b.driverName == model.Mssql { - return b.insertForMssqlOrPostgres(sqlStr+"; SELECT SCOPE_IDENTITY()", args...) - } else if b.driverName == model.Postgres { - sqlStr = convertToPostgresSql(sqlStr) - return b.insertForMssqlOrPostgres(sqlStr+" RETURNING "+primaryKey, args...) + if b.LinkCommon.DriverName() == model.Mssql { + return b.insertForMssqlOrPostgres(sql+"; SELECT SCOPE_IDENTITY()", args...) + } else if b.LinkCommon.DriverName() == model.Postgres { + sql = convertToPostgresSql(sql) + return b.insertForMssqlOrPostgres(sql+" RETURNING "+primaryKey, args...) } else { - return b.insertForCommon(sqlStr, args...) + return b.insertForCommon(sql, args...) } } @@ -196,7 +189,7 @@ func (b *Builder) InsertBatch(values interface{}) (int64, error) { sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf.Index(0)) + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") - if b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Postgres { sqlStr = convertToPostgresSql(sqlStr) } @@ -353,7 +346,7 @@ func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder { // Truncate 清空记录 func (b *Builder) Truncate() (int64, error) { sqlStr := "" - if b.driverName == model.Sqlite3 { + if b.LinkCommon.DriverName() == model.Sqlite3 { sqlStr = "DELETE FROM " + getTableNameByTable(b.table) } else { sqlStr = "TRUNCATE TABLE " + getTableNameByTable(b.table) @@ -373,7 +366,7 @@ func (b *Builder) RawSql(sql string, args ...interface{}) *Builder { func (b *Builder) GetRows() (*sql.Rows, error) { sql, args := b.GetSqlAndParams() - if b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Postgres { sql = convertToPostgresSql(sql) } @@ -398,7 +391,7 @@ func (b *Builder) GetRows() (*sql.Rows, error) { // Exec 通用执行-新增,更新,删除 func (b *Builder) Exec() (sql.Result, error) { - if b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Postgres { b.sql = convertToPostgresSql(b.sql) } @@ -432,7 +425,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo } //如果是mssql或者Postgres,并且来自having的话,需要特殊处理 - if (b.driverName == model.Mssql || b.driverName == model.Postgres) && isFromHaving { + if (b.LinkCommon.DriverName() == model.Mssql || b.LinkCommon.DriverName() == model.Postgres) && isFromHaving { fieldNameCurrent := getFieldName(where[i].Field) for m := 0; m < len(b.selectList); m++ { if fieldNameCurrent == getFieldName(b.selectList[m].FieldNew) { @@ -453,7 +446,7 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo } } else { if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { - if b.driverName == model.Sqlite3 { + if b.LinkCommon.DriverName() == model.Sqlite3 { whereList = append(whereList, allFieldName+" "+where[i].Opt+" "+"?") } else { switch where[i].Val.(type) { @@ -516,9 +509,9 @@ func (b *Builder) whereAndHaving(where []WhereItem, args []any, isFromHaving boo } func (b *Builder) getConcatForFloat(vars ...string) string { - if b.driverName == model.Sqlite3 { + if b.LinkCommon.DriverName() == model.Sqlite3 { return strings.Join(vars, "||") - } else if b.driverName == model.Postgres { + } else if b.LinkCommon.DriverName() == model.Postgres { return vars[0] } else { return "CONCAT(" + strings.Join(vars, ",") + ")" @@ -526,7 +519,7 @@ func (b *Builder) getConcatForFloat(vars ...string) string { } func (b *Builder) getConcatForLike(vars ...string) string { - if b.driverName == model.Sqlite3 || b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Sqlite3 || b.LinkCommon.DriverName() == model.Postgres { return strings.Join(vars, "||") } else { return "CONCAT(" + strings.Join(vars, ",") + ")" @@ -564,7 +557,7 @@ func (b *Builder) GetSqlAndParams() (string, []interface{}) { // execAffected 通用执行-更新,删除 func (b *Builder) execAffected(sql string, args ...interface{}) (int64, error) { - if b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Postgres { sql = convertToPostgresSql(sql) } diff --git a/builder/handle.go b/builder/handle.go index 62b9fa1..cc93735 100644 --- a/builder/handle.go +++ b/builder/handle.go @@ -168,7 +168,7 @@ func (b *Builder) handleLimit(paramList []any) (string, []any) { } str := "" - if b.driverName == model.Postgres { + if b.LinkCommon.DriverName() == model.Postgres { paramList = append(paramList, b.limitItem.pageSize) paramList = append(paramList, b.limitItem.offset) @@ -178,7 +178,7 @@ func (b *Builder) handleLimit(paramList []any) (string, []any) { paramList = append(paramList, b.limitItem.pageSize) str = " Limit ?,? " - if b.driverName == model.Mssql { + if b.LinkCommon.DriverName() == model.Mssql { str = " offset ? rows fetch next ? rows only " } } diff --git a/migrate_mssql/migrate.go b/migrate_mssql/migrate.go index 5f666ba..97eb0f2 100644 --- a/migrate_mssql/migrate.go +++ b/migrate_mssql/migrate.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" - "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" "reflect" "strconv" @@ -33,20 +32,14 @@ type Index struct { //MigrateExecutor 定义结构 type MigrateExecutor struct { - //驱动名字 - DriverName string - - //表属性 - OpinionList []model.OpinionItem - //执行者 - Ex *builder.Builder + Builder *builder.Builder } //ShowCreateTable 查看创建表的ddl func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { var str string - mm.Ex.RawSql("show create table "+tableName).Value("Create Table", &str) + mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) return str } @@ -142,7 +135,7 @@ func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode func (mm *MigrateExecutor) getDbName() (string, error) { //获取数据库名称 var dbName string - err := mm.Ex.RawSql("Select Name as db_name From Master..SysDataBases Where DbId=(Select Dbid From Master..SysProcesses Where Spid = @@spid)").Value("db_name", &dbName) + err := mm.Builder.RawSql("Select Name as db_name From Master..SysDataBases Where DbId=(Select Dbid From Master..SysProcesses Where Spid = @@spid)").Value("db_name", &dbName) if err != nil { return "", err } @@ -153,7 +146,7 @@ func (mm *MigrateExecutor) getDbName() (string, error) { func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { sql := "SELECT Name as TABLE_NAME FROM SysObjects Where XType='U' and Name =" + "'" + tableName + "'" var dataList []Table - mm.Ex.RawSql(sql).GetMany(&dataList) + mm.Builder.RawSql(sql).GetMany(&dataList) return dataList } @@ -177,7 +170,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C "Left Join sys.extended_properties F On D.id=F.major_id and F.minor_id=0 " + "Order By A.id,A.colorder" - mm.Ex.RawSql(sqlColumn).GetMany(&columnsFromDb) + mm.Builder.RawSql(sqlColumn).GetMany(&columnsFromDb) return columnsFromDb } @@ -203,7 +196,7 @@ func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { "AND t.name = '" + tableName + "'" var indexesFromDb []Index - mm.Ex.RawSql(sqlIndex).GetMany(&indexesFromDb) + mm.Builder.RawSql(sqlIndex).GetMany(&indexesFromDb) return indexesFromDb } @@ -220,7 +213,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) fmt.Println(sql) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -232,7 +225,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -261,7 +254,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if !keyMatch || indexCode.NonUnique.Int64 != indexDb.NonUnique.Int64 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -273,7 +266,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -298,7 +291,7 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co sqlStr := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - _, err := mm.Ex.RawSql(sqlStr).Exec() + _, err := mm.Builder.RawSql(sqlStr).Exec() if err != nil { fmt.Println(err) } else { @@ -306,17 +299,6 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co } } -func (mm *MigrateExecutor) getOpinionVal(key string, def string) string { - opinions := mm.OpinionList - for i := 0; i < len(opinions); i++ { - opinionItem := opinions[i] - if opinionItem.Key == key { - def = opinionItem.Val - } - } - return def -} - func getTagMap(fieldTag string) map[string]string { var fieldMap = make(map[string]string) if "" != fieldTag { diff --git a/migrate_mysql/migrate.go b/migrate_mysql/migrate.go index ec59ff9..f88b13a 100644 --- a/migrate_mysql/migrate.go +++ b/migrate_mysql/migrate.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" - "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" "reflect" "strconv" @@ -35,20 +34,14 @@ type Index struct { //MigrateExecutor 定义结构 type MigrateExecutor struct { - //驱动名字 - DriverName string - - //表属性 - OpinionList []model.OpinionItem - //执行者 - Ex *builder.Builder + Builder *builder.Builder } //ShowCreateTable 查看创建表的ddl func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { var str string - mm.Ex.RawSql("show create table "+tableName).Value("Create Table", &str) + mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) return str } @@ -166,7 +159,7 @@ func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode func (mm *MigrateExecutor) getDbName() (string, error) { //获取数据库名称 var dbName string - err := mm.Ex.RawSql("SELECT DATABASE()").Value("DATABASE()", &dbName) + err := mm.Builder.RawSql("SELECT DATABASE()").Value("DATABASE()", &dbName) if err != nil { return "", err } @@ -177,7 +170,7 @@ func (mm *MigrateExecutor) getDbName() (string, error) { func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { sql := "SELECT TABLE_NAME,ENGINE,TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" var dataList []Table - mm.Ex.RawSql(sql).GetMany(&dataList) + mm.Builder.RawSql(sql).GetMany(&dataList) for i := 0; i < len(dataList); i++ { dataList[i].TableComment = null.StringFrom("'" + dataList[i].TableComment.String + "'") } @@ -189,7 +182,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C var columnsFromDb []Column sqlColumn := "SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH as Max_Length,COLUMN_DEFAULT,COLUMN_COMMENT,EXTRA,IS_NULLABLE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA =" + "'" + dbName + "' AND TABLE_NAME =" + "'" + tableName + "'" - mm.Ex.RawSql(sqlColumn).GetMany(&columnsFromDb) + mm.Builder.RawSql(sqlColumn).GetMany(&columnsFromDb) for j := 0; j < len(columnsFromDb); j++ { if columnsFromDb[j].DataType.String == "text" && columnsFromDb[j].MaxLength.Int64 == 65535 { @@ -204,7 +197,7 @@ func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { sqlIndex := "SHOW INDEXES FROM " + tableName var indexsFromDb []Index - mm.Ex.RawSql(sqlIndex).GetMany(&indexsFromDb) + mm.Builder.RawSql(sqlIndex).GetMany(&indexsFromDb) return indexsFromDb } @@ -232,7 +225,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co columnCode.Extra.String != columnDb.Extra.String || columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -244,7 +237,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -263,7 +256,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -275,7 +268,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getIndexStr(indexCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -287,7 +280,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co func (mm *MigrateExecutor) modifyTableEngine(tableFromCode Table) { sql := "ALTER TABLE " + tableFromCode.TableName.String + " Engine " + tableFromCode.Engine.String - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -297,7 +290,7 @@ func (mm *MigrateExecutor) modifyTableEngine(tableFromCode Table) { func (mm *MigrateExecutor) modifyTableComment(tableFromCode Table) { sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -319,7 +312,7 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co } sql := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + " ENGINE " + tableFromCode.Engine.String + " COMMENT " + tableFromCode.TableComment.String + ";" - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -327,17 +320,6 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co } } -func (mm *MigrateExecutor) getOpinionVal(key string, def string) string { - opinions := mm.OpinionList - for i := 0; i < len(opinions); i++ { - opinionItem := opinions[i] - if opinionItem.Key == key { - def = opinionItem.Val - } - } - return def -} - func getTagMap(fieldTag string) map[string]string { var fieldMap = make(map[string]string) if "" != fieldTag { diff --git a/migrate_postgres/migrate.go b/migrate_postgres/migrate.go index 7c3bf0d..61e61c9 100644 --- a/migrate_postgres/migrate.go +++ b/migrate_postgres/migrate.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" - "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" "reflect" "regexp" @@ -43,26 +42,20 @@ type Index struct { //MigrateExecutor 定义结构 type MigrateExecutor struct { - //驱动名字 - DriverName string - - //表属性 - OpinionList []model.OpinionItem - //执行者 - Ex *builder.Builder + Builder *builder.Builder } //ShowCreateTable 查看创建表的ddl func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { var str string - mm.Ex.RawSql("show create table "+tableName).Value("Create Table", &str) + mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) return str } //MigrateCommon 迁移的主要过程 -func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) error { - tableFromCode := mm.getTableFromCode(tableName) +func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) error { + tableFromCode := mm.getTableFromCode(tableName, typeOf, valueOf) columnsFromCode := mm.getColumnsFromCode(typeOf) indexesFromCode := mm.getIndexesFromCode(typeOf, tableFromCode) @@ -85,12 +78,25 @@ func (mm *MigrateExecutor) MigrateCommon(tableName string, typeOf reflect.Type) return nil } -func (mm *MigrateExecutor) getTableFromCode(tableName string) Table { - var tableFromCode Table - tableFromCode.TableName = null.StringFrom(tableName) - tableFromCode.TableComment = null.StringFrom(mm.getOpinionVal("COMMENT", "")) +func (mm *MigrateExecutor) getTableFromCode(tableName string, typeOf reflect.Type, valueOf reflect.Value) Table { + table := Table{ + TableName: null.StringFrom(tableName), + TableComment: null.StringFrom("''"), + } - return tableFromCode + method, isSet := typeOf.MethodByName("TableOpinion") + if isSet { + var paramList []reflect.Value + paramList = append(paramList, valueOf) + valueList := method.Func.Call(paramList) + i := valueList[0].Interface() + m := i.(map[string]string) + + m["COMMENT"] = "'" + m["COMMENT"] + "'" + table.TableComment = null.StringFrom(m["COMMENT"]) + } + + return table } func (mm *MigrateExecutor) getColumnsFromCode(typeOf reflect.Type) []Column { @@ -153,7 +159,7 @@ func (mm *MigrateExecutor) getIndexesFromCode(typeOf reflect.Type, tableFromCode func (mm *MigrateExecutor) getDbName() (string, error) { //获取数据库名称 var dbName string - err := mm.Ex.RawSql("select current_database()").Value("current_database", &dbName) + err := mm.Builder.RawSql("select current_database()").Value("current_database", &dbName) if err != nil { return "", err } @@ -164,7 +170,7 @@ func (mm *MigrateExecutor) getDbName() (string, error) { func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { sql := "select a.relname as TABLE_NAME, b.description as TABLE_COMMENT from pg_class a left join (select * from pg_description where objsubid =0) b on a.oid = b.objoid where a.relname in (select tablename from pg_tables where schemaname = 'public' and tablename = " + "'" + tableName + "') order by a.relname asc" var dataList []Table - mm.Ex.RawSql(sql).GetMany(&dataList) + mm.Builder.RawSql(sql).GetMany(&dataList) for i := 0; i < len(dataList); i++ { dataList[i].TableComment = null.StringFrom("'" + dataList[i].TableComment.String + "'") } @@ -177,7 +183,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C sqlColumn := "select column_name,data_type,character_maximum_length as max_length,column_default,'' as COLUMN_COMMENT, is_nullable from information_schema.columns where table_schema='public' and table_name=" + "'" + tableName + "'" - mm.Ex.RawSql(sqlColumn).GetMany(&columnsFromDb) + mm.Builder.RawSql(sqlColumn).GetMany(&columnsFromDb) for j := 0; j < len(columnsFromDb); j++ { if columnsFromDb[j].DataType.String == "character varying" { @@ -199,7 +205,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { sqlIndex := "select * from pg_indexes where tablename=" + "'" + tableName + "'" var sqliteMasterList []PgIndexes - mm.Ex.RawSql(sqlIndex).GetMany(&sqliteMasterList) + mm.Builder.RawSql(sqlIndex).GetMany(&sqliteMasterList) var indexesFromDb []Index for i := 0; i < len(sqliteMasterList); i++ { @@ -252,7 +258,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co sql := "ALTER TABLE " + tableFromCode.TableName.String + " alter COLUMN " + getColumnStr(columnCode, "type") //fmt.Println(sql) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -264,7 +270,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode, "") - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -283,7 +289,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -301,7 +307,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co func (mm *MigrateExecutor) modifyTableComment(tableFromCode Table) { sql := "ALTER TABLE " + tableFromCode.TableName.String + " Comment " + tableFromCode.TableComment.String - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -326,7 +332,7 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co sql := "CREATE TABLE " + tableFromCode.TableName.String + " (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -349,7 +355,7 @@ func (mm *MigrateExecutor) createIndex(tableName string, index Index) { } sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -357,17 +363,6 @@ func (mm *MigrateExecutor) createIndex(tableName string, index Index) { } } -func (mm *MigrateExecutor) getOpinionVal(key string, def string) string { - opinions := mm.OpinionList - for i := 0; i < len(opinions); i++ { - opinionItem := opinions[i] - if opinionItem.Key == key { - def = opinionItem.Val - } - } - return def -} - func getTagMap(fieldTag string) map[string]string { var fieldMap = make(map[string]string) if "" != fieldTag { diff --git a/migrate_sqlite3/migrate.go b/migrate_sqlite3/migrate.go index 5cfbe48..7dce021 100644 --- a/migrate_sqlite3/migrate.go +++ b/migrate_sqlite3/migrate.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" - "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" "reflect" "regexp" @@ -41,20 +40,14 @@ type Index struct { //MigrateExecutor 定义结构 type MigrateExecutor struct { - //驱动名字 - DriverName string - - //表属性 - OpinionList []model.OpinionItem - //执行者 - Ex *builder.Builder + Builder *builder.Builder } //ShowCreateTable 查看创建表的ddl func (mm *MigrateExecutor) ShowCreateTable(tableName string) string { var str string - mm.Ex.RawSql("show create table "+tableName).Value("Create Table", &str) + mm.Builder.RawSql("show create table "+tableName).Value("Create Table", &str) return str } @@ -153,7 +146,7 @@ func (mm *MigrateExecutor) getDbName() (string, error) { func (mm *MigrateExecutor) getTableFromDb(dbName string, tableName string) []Table { sql := "select * from sqlite_master where type='table' and tbl_name=" + "'" + tableName + "'" var sqliteMasterList []SqliteMaster - mm.Ex.RawSql(sql).GetMany(&sqliteMasterList) + mm.Builder.RawSql(sql).GetMany(&sqliteMasterList) var dataList []Table for i := 0; i < len(sqliteMasterList); i++ { @@ -170,7 +163,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C var sqliteMaster SqliteMaster sqlColumn1 := "select * from sqlite_master where type='table' and tbl_name = " + "'" + tableName + "'" - mm.Ex.RawSql(sqlColumn1).GetOne(&sqliteMaster) + mm.Builder.RawSql(sqlColumn1).GetOne(&sqliteMaster) str := sqliteMaster.Sql.String str = strings.ReplaceAll(str, "\n", "") @@ -205,7 +198,7 @@ func (mm *MigrateExecutor) getColumnsFromDb(dbName string, tableName string) []C func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { sqlIndex := "select * from sqlite_master where type = 'index' and name not like '%sqlite_autoindex%' and tbl_name=" + "'" + tableName + "'" var sqliteMasterList []SqliteMaster - mm.Ex.RawSql(sqlIndex).GetMany(&sqliteMasterList) + mm.Builder.RawSql(sqlIndex).GetMany(&sqliteMasterList) var indexesFromDb []Index for i := 0; i < len(sqliteMasterList); i++ { @@ -229,7 +222,7 @@ func (mm *MigrateExecutor) getIndexesFromDb(tableName string) []Index { //查询是否有主键索引 sql := "select * from sqlite_master where type='table' and tbl_name=" + "'" + tableName + "'" var sqliteMaster SqliteMaster - mm.Ex.RawSql(sql).GetOne(&sqliteMaster) + mm.Builder.RawSql(sql).GetOne(&sqliteMaster) compileRegex := regexp.MustCompile("PRIMARY\\sKEY\\s\\((.*?)\\)") matchArr2 := compileRegex.FindAllStringSubmatch(sqliteMaster.Sql.String, -1) @@ -258,7 +251,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co columnCode.ColumnDefault.String != columnDb.ColumnDefault.String { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getColumnStr(columnCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(sql) fmt.Println(err) @@ -271,7 +264,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co if isFind == 0 { sql := "ALTER TABLE " + tableFromCode.TableName.String + " ADD " + getColumnStr(columnCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(sql) fmt.Println(err) @@ -291,7 +284,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co isFind = 1 if indexCode.KeyName != indexDb.KeyName || indexCode.NonUnique != indexDb.NonUnique { sql := "ALTER TABLE " + tableFromCode.TableName.String + " MODIFY " + getIndexStr(indexCode) - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -324,7 +317,7 @@ func (mm *MigrateExecutor) createTable(tableFromCode Table, columnsFromCode []Co //创建表结构与主键索引 sql := "CREATE TABLE `" + tableFromCode.TableName.String + "` (\n" + strings.Join(fieldArr, ",\n") + "\n) " + ";" - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -347,7 +340,7 @@ func (mm *MigrateExecutor) createIndex(tableName string, index Index) { } sql := "CREATE " + keyType + " INDEX " + index.KeyName.String + " on " + tableName + " (" + index.ColumnName.String + ")" - _, err := mm.Ex.RawSql(sql).Exec() + _, err := mm.Builder.RawSql(sql).Exec() if err != nil { fmt.Println(err) } else { @@ -355,17 +348,6 @@ func (mm *MigrateExecutor) createIndex(tableName string, index Index) { } } -func (mm *MigrateExecutor) getOpinionVal(key string, def string) string { - opinions := mm.OpinionList - for i := 0; i < len(opinions); i++ { - opinionItem := opinions[i] - if opinionItem.Key == key { - def = opinionItem.Val - } - } - return def -} - func getTagMap(fieldTag string) map[string]string { var fieldMap = make(map[string]string) if "" != fieldTag { diff --git a/migrator/migrator.go b/migrator/migrator.go index c0cd28f..8279235 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -15,36 +15,13 @@ import ( type Migrator struct { //数据库操作连接 LinkCommon model.LinkCommon - - //驱动名字 - driverName string - - //表属性 - opinionList []model.OpinionItem -} - -func (mi *Migrator) Driver(driverName string) *Migrator { - mi.driverName = driverName - return mi -} - -func (mi *Migrator) Opinion(key string, val string) *Migrator { - if key == "COMMENT" { - val = "'" + val + "'" - } - - mi.opinionList = append(mi.opinionList, model.OpinionItem{Key: key, Val: val}) - - return mi } //ShowCreateTable 获取创建表的ddl func (mi *Migrator) ShowCreateTable(tableName string) string { - if mi.driverName == model.Mysql { + if mi.LinkCommon.DriverName() == model.Mysql { me := migrate_mysql.MigrateExecutor{ - DriverName: mi.driverName, - OpinionList: mi.opinionList, - Ex: &builder.Builder{ + Builder: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } @@ -72,48 +49,40 @@ func (mi *Migrator) Migrate(tableName string, dest interface{}) { } func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type, valueOf reflect.Value) { - if mi.driverName == model.Mssql { + if mi.LinkCommon.DriverName() == model.Mssql { me := migrate_mssql.MigrateExecutor{ - DriverName: mi.driverName, - OpinionList: mi.opinionList, - Ex: &builder.Builder{ + Builder: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } me.MigrateCommon(tableName, typeOf) } - if mi.driverName == model.Mysql { + if mi.LinkCommon.DriverName() == model.Mysql { me := migrate_mysql.MigrateExecutor{ - DriverName: mi.driverName, - OpinionList: mi.opinionList, - Ex: &builder.Builder{ + Builder: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } me.MigrateCommon(tableName, typeOf, valueOf) } - if mi.driverName == model.Sqlite3 { + if mi.LinkCommon.DriverName() == model.Sqlite3 { me := migrate_sqlite3.MigrateExecutor{ - DriverName: mi.driverName, - OpinionList: mi.opinionList, - Ex: &builder.Builder{ + Builder: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } me.MigrateCommon(tableName, typeOf) } - if mi.driverName == model.Postgres { + if mi.LinkCommon.DriverName() == model.Postgres { me := migrate_postgres.MigrateExecutor{ - DriverName: mi.driverName, - OpinionList: mi.opinionList, - Ex: &builder.Builder{ + Builder: &builder.Builder{ LinkCommon: mi.LinkCommon, }, } - me.MigrateCommon(tableName, typeOf) + me.MigrateCommon(tableName, typeOf, valueOf) } } diff --git a/model/AormDB.go b/model/AormDB.go new file mode 100644 index 0000000..eefbef1 --- /dev/null +++ b/model/AormDB.go @@ -0,0 +1,52 @@ +package model + +import "database/sql" + +// AormDB 数据库连接与数据库类型 +type AormDB struct { + Driver string + DebugMode bool + SqlDB *sql.DB +} + +//Begin 开始一个事务 +func (db *AormDB) Begin() *AormTx { + SqlTx, _ := db.SqlDB.Begin() + + return &AormTx{ + driver: db.Driver, + debugMode: db.DebugMode, + + sqlTx: SqlTx, + } +} + +//SetDebugMode 获取调试模式 +func (db *AormDB) SetDebugMode(debugMode bool) { + db.DebugMode = debugMode +} + +//GetDebugMode 获取调试模式 +func (db *AormDB) GetDebugMode() bool { + return db.DebugMode +} + +func (db *AormDB) DriverName() string { + return db.Driver +} + +func (db *AormDB) Exec(query string, args ...interface{}) (sql.Result, error) { + return db.SqlDB.Exec(query, args...) +} + +func (db *AormDB) Prepare(query string) (*sql.Stmt, error) { + return db.SqlDB.Prepare(query) +} + +func (db *AormDB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return db.SqlDB.Query(query, args...) +} + +func (db *AormDB) QueryRow(query string, args ...interface{}) *sql.Row { + return db.SqlDB.QueryRow(query, args...) +} diff --git a/model/AormTx.go b/model/AormTx.go new file mode 100644 index 0000000..b9dd9b6 --- /dev/null +++ b/model/AormTx.go @@ -0,0 +1,42 @@ +package model + +import "database/sql" + +type AormTx struct { + driver string + debugMode bool + sqlTx *sql.Tx +} + +//GetDebugMode 获取调试状态 +func (tx *AormTx) GetDebugMode() bool { + return tx.debugMode +} + +func (tx *AormTx) DriverName() string { + return tx.driver +} + +func (tx *AormTx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.sqlTx.Exec(query, args...) +} + +func (tx *AormTx) Prepare(query string) (*sql.Stmt, error) { + return tx.sqlTx.Prepare(query) +} + +func (tx *AormTx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return tx.sqlTx.Query(query, args...) +} + +func (tx *AormTx) QueryRow(query string, args ...interface{}) *sql.Row { + return tx.sqlTx.QueryRow(query, args...) +} + +func (tx *AormTx) Rollback() error { + return tx.sqlTx.Rollback() +} + +func (tx *AormTx) Commit() error { + return tx.sqlTx.Commit() +} diff --git a/model/model.go b/model/model.go index b6f339b..5affde9 100644 --- a/model/model.go +++ b/model/model.go @@ -3,17 +3,14 @@ package model import "database/sql" type LinkCommon interface { + GetDebugMode() bool + DriverName() string Exec(query string, args ...interface{}) (sql.Result, error) Prepare(query string) (*sql.Stmt, error) Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row } -type OpinionItem struct { - Key string - Val string -} - type FieldInfo struct { TablePointer uintptr Name string diff --git a/test/aorm_test.go b/test/aorm_test.go index 9a4ae1a..74ddd11 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -1,7 +1,6 @@ package test import ( - "database/sql" "fmt" _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" @@ -96,134 +95,136 @@ func TestAll(t *testing.T) { aorm.Store(&articleVO) aorm.Store(&personAge, &personWithArticleCount) - //var dbList = []*aorm.DbContent{ - // testMysqlConnect(), - // testSqlite3Connect(), - // testPostgresConnect(), - // testMssqlConnect(), - //} - // - //for i := 0; i < len(dbList); i++ { - // dbItem := dbList[i] - // - // testMigrate(dbItem.DriverName, dbItem.DbLink) - // - // testShowCreateTable(dbItem.DriverName, dbItem.DbLink) - // - // id := testInsert(dbItem.DriverName, dbItem.DbLink) - // testInsertBatch(dbItem.DriverName, dbItem.DbLink) - // testGetOne(dbItem.DriverName, dbItem.DbLink, id) - // testGetMany(dbItem.DriverName, dbItem.DbLink) - // testUpdate(dbItem.DriverName, dbItem.DbLink, id) - // - // isExists := testExists(dbItem.DriverName, dbItem.DbLink, id) - // if isExists != true { - // panic("应该存在,但是数据库不存在") - // } - // - // testDelete(dbItem.DriverName, dbItem.DbLink, id) - // isExists2 := testExists(dbItem.DriverName, dbItem.DbLink, id) - // if isExists2 == true { - // panic("应该不存在,但是数据库存在") - // } - // - // id2 := testInsert(dbItem.DriverName, dbItem.DbLink) - // testTable(dbItem.DriverName, dbItem.DbLink) - // testSelect(dbItem.DriverName, dbItem.DbLink) - // testSelectWithSub(dbItem.DriverName, dbItem.DbLink) - // testWhereWithSub(dbItem.DriverName, dbItem.DbLink) - // testWhere(dbItem.DriverName, dbItem.DbLink) - // testJoin(dbItem.DriverName, dbItem.DbLink) - // testJoinWithAlias(dbItem.DriverName, dbItem.DbLink) - // - // testGroupBy(dbItem.DriverName, dbItem.DbLink) - // testHaving(dbItem.DriverName, dbItem.DbLink) - // testOrderBy(dbItem.DriverName, dbItem.DbLink) - // testLimit(dbItem.DriverName, dbItem.DbLink) - // testLock(dbItem.DriverName, dbItem.DbLink, id2) - // - // testIncrement(dbItem.DriverName, dbItem.DbLink, id2) - // testDecrement(dbItem.DriverName, dbItem.DbLink, id2) - // - // testValue(dbItem.DriverName, dbItem.DbLink, id2) - // testPluck(dbItem.DriverName, dbItem.DbLink) - // - // testCount(dbItem.DriverName, dbItem.DbLink) - // testSum(dbItem.DriverName, dbItem.DbLink) - // testAvg(dbItem.DriverName, dbItem.DbLink) - // testMin(dbItem.DriverName, dbItem.DbLink) - // testMax(dbItem.DriverName, dbItem.DbLink) - // - // testDistinct(dbItem.DriverName, dbItem.DbLink) - // - // testRawSql(dbItem.DriverName, dbItem.DbLink, id2) - // - // testTransaction(dbItem.DriverName, dbItem.DbLink) - // testTruncate(dbItem.DriverName, dbItem.DbLink) - //} - // - //testPreview() + var dbList = []*model.AormDB{ + testMysqlConnect(), + testSqlite3Connect(), + testPostgresConnect(), + testMssqlConnect(), + } + for i := 0; i < len(dbList); i++ { + dbItem := dbList[i] + + testMigrate(dbItem) + testShowCreateTable(dbItem) + + id := testInsert(dbItem) + testInsertBatch(dbItem) + testGetOne(dbItem, id) + testGetMany(dbItem) + testUpdate(dbItem, id) + + isExists := testExists(dbItem, id) + if isExists != true { + panic("应该存在,但是数据库不存在") + } + + testDelete(dbItem, id) + isExists2 := testExists(dbItem, id) + if isExists2 == true { + panic("应该不存在,但是数据库存在") + } + + id2 := testInsert(dbItem) + testTable(dbItem) + testSelect(dbItem) + testSelectWithSub(dbItem) + testWhereWithSub(dbItem) + testWhere(dbItem) + testJoin(dbItem) + testJoinWithAlias(dbItem) + + testGroupBy(dbItem) + testHaving(dbItem) + testOrderBy(dbItem) + testLimit(dbItem) + testLock(dbItem, id2) + + testIncrement(dbItem, id2) + testDecrement(dbItem, id2) + + testValue(dbItem, id2) + testPluck(dbItem) + + testCount(dbItem) + testSum(dbItem) + testAvg(dbItem) + testMin(dbItem) + testMax(dbItem) + + testDistinct(dbItem) + testRawSql(dbItem, id2) + + testTransaction(dbItem) + testTruncate(dbItem) + } + + testPreview() testDbContent() } -func testSqlite3Connect() *aorm.DbContent { - sqlite3Content, sqlite3Err := aorm.Open("sqlite3", "test.db") +func testSqlite3Connect() *model.AormDB { + sqlite3Content, sqlite3Err := aorm.Open(model.Sqlite3, "test.db") if sqlite3Err != nil { panic(sqlite3Err) } + sqlite3Content.SetDebugMode(false) return sqlite3Content } -func testMysqlConnect() *aorm.DbContent { +func testMysqlConnect() *model.AormDB { username := "root" password := "root" hostname := "localhost" port := "3306" dbname := "database_name" - mysqlContent, mysqlErr := aorm.Open("mysql", username+":"+password+"@tcp("+hostname+":"+port+")/"+dbname+"?charset=utf8mb4&parseTime=True&loc=Local") + mysqlContent, mysqlErr := aorm.Open(model.Mysql, username+":"+password+"@tcp("+hostname+":"+port+")/"+dbname+"?charset=utf8mb4&parseTime=True&loc=Local") if mysqlErr != nil { panic(mysqlErr) } + mysqlContent.SetDebugMode(false) return mysqlContent } -func testPostgresConnect() *aorm.DbContent { +func testPostgresConnect() *model.AormDB { psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", "localhost", 5432, "postgres", "root", "postgres") - postgresContent, postgresErr := aorm.Open("postgres", psqlInfo) + postgresContent, postgresErr := aorm.Open(model.Postgres, psqlInfo) if postgresErr != nil { panic(postgresErr) } + postgresContent.SetDebugMode(false) + return postgresContent } -func testMssqlConnect() *aorm.DbContent { +func testMssqlConnect() *model.AormDB { mssqlInfo := fmt.Sprintf("server=%s;database=%s;user id=%s;password=%s;port=%d;encrypt=disable", "localhost", "database_name", "sa", "root", 1433) - mssqlContent, mssqlErr := aorm.Open("mssql", mssqlInfo) + mssqlContent, mssqlErr := aorm.Open(model.Mssql, mssqlInfo) if mssqlErr != nil { panic(mssqlErr) } + mssqlContent.SetDebugMode(false) return mssqlContent } -func testMigrate(driver string, db *sql.DB) { - aorm.Migrator(db).Driver(driver).AutoMigrate(&person, &article, &student) +func testMigrate(db *model.AormDB) { + aorm.Migrator(db).AutoMigrate(&person, &article, &student) - aorm.Migrator(db).Driver(driver).Migrate("person_1", &person) + aorm.Migrator(db).Migrate("person_1", &person) } -func testShowCreateTable(driver string, db *sql.DB) { - aorm.Migrator(db).Driver(driver).ShowCreateTable("person") +func testShowCreateTable(db *model.AormDB) { + aorm.Migrator(db).ShowCreateTable("person") } -func testInsert(driver string, db *sql.DB) int64 { +func testInsert(db *model.AormDB) int64 { obj := Person{ Name: null.StringFrom("Alice"), Sex: null.BoolFrom(true), @@ -234,55 +235,55 @@ func testInsert(driver string, db *sql.DB) int64 { Test: null.FloatFrom(2), } - id, errInsert := aorm.Db(db).Debug(false).Driver(driver).Insert(&obj) + id, errInsert := aorm.Db(db).Insert(&obj) if errInsert != nil { - panic(driver + " testInsert " + "found err: " + errInsert.Error()) + panic(db.DriverName() + " testInsert " + "found err: " + errInsert.Error()) } - aorm.Db(db).Debug(false).Driver(driver).Insert(&Article{ + aorm.Db(db).Insert(&Article{ Type: null.IntFrom(0), PersonId: null.IntFrom(id), ArticleBody: null.StringFrom("文章内容"), }) var personItem Person - err := aorm.Db(db).Table(&person).Debug(false).Driver(driver).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) + err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) if err != nil { fmt.Println(err.Error()) } if obj.Name.String != personItem.Name.String { - fmt.Println(driver + ",Name not match, expected: " + obj.Name.String + " ,but real is : " + personItem.Name.String) + fmt.Println(db.DriverName() + ",Name not match, expected: " + obj.Name.String + " ,but real is : " + personItem.Name.String) } if obj.Sex.Bool != personItem.Sex.Bool { - fmt.Println(driver + ",Sex not match, expected: " + fmt.Sprintf("%v", obj.Sex.Bool) + " ,but real is : " + fmt.Sprintf("%v", personItem.Sex.Bool)) + fmt.Println(db.DriverName() + ",Sex not match, expected: " + fmt.Sprintf("%v", obj.Sex.Bool) + " ,but real is : " + fmt.Sprintf("%v", personItem.Sex.Bool)) } if obj.Age.Int64 != personItem.Age.Int64 { - fmt.Println(driver + ",Age not match, expected: " + fmt.Sprintf("%v", obj.Age.Int64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Age.Int64)) + fmt.Println(db.DriverName() + ",Age not match, expected: " + fmt.Sprintf("%v", obj.Age.Int64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Age.Int64)) } if obj.Type.Int64 != personItem.Type.Int64 { - fmt.Println(driver + ",Type not match, expected: " + fmt.Sprintf("%v", obj.Type.Int64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Type.Int64)) + fmt.Println(db.DriverName() + ",Type not match, expected: " + fmt.Sprintf("%v", obj.Type.Int64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Type.Int64)) } if obj.Money.Float64 != personItem.Money.Float64 { - fmt.Println(driver + ",Money not match, expected: " + fmt.Sprintf("%v", obj.Money.Float64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Money.Float64)) + fmt.Println(db.DriverName() + ",Money not match, expected: " + fmt.Sprintf("%v", obj.Money.Float64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Money.Float64)) } if obj.Test.Float64 != personItem.Test.Float64 { - fmt.Println(driver + ",Test not match, expected: " + fmt.Sprintf("%v", obj.Test.Float64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Test.Float64)) + fmt.Println(db.DriverName() + ",Test not match, expected: " + fmt.Sprintf("%v", obj.Test.Float64) + " ,but real is : " + fmt.Sprintf("%v", personItem.Test.Float64)) } //测试非id主键 - aorm.Db(db).Debug(false).Driver(driver).Insert(&Student{ + aorm.Db(db).Insert(&Student{ Name: null.StringFrom("new student"), }) return id } -func testInsertBatch(driver string, db *sql.DB) int64 { +func testInsertBatch(db *model.AormDB) int64 { var batch []*Person batch = append(batch, &Person{ Name: null.StringFrom("Alice"), @@ -304,85 +305,84 @@ func testInsertBatch(driver string, db *sql.DB) int64 { Test: null.FloatFrom(200.15987654321987654321), }) - count, err := aorm.Db(db).Debug(false).Driver(driver).InsertBatch(&batch) + count, err := aorm.Db(db).InsertBatch(&batch) if err != nil { - panic(driver + " testInsertBatch " + "found err:" + err.Error()) + panic(db.DriverName() + " testInsertBatch " + "found err:" + err.Error()) } return count } -func testGetOne(driver string, db *sql.DB, id int64) { +func testGetOne(db *model.AormDB, id int64) { var personItem Person - errFind := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).GetOne(&personItem) + errFind := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).GetOne(&personItem) if errFind != nil { - panic(driver + "testGetOne" + "found err") + panic(db.DriverName() + "testGetOne" + "found err") } } -func testGetMany(driver string, db *sql.DB) { +func testGetMany(db *model.AormDB) { var list []Person - errSelect := aorm.Db(db).Driver(driver).Debug(false).Table(&person).WhereEq(&person.Type, 0).GetMany(&list) + errSelect := aorm.Db(db).Table(&person).WhereEq(&person.Type, 0).GetMany(&list) if errSelect != nil { - panic(driver + " testGetMany " + "found err:" + errSelect.Error()) + panic(db.DriverName() + " testGetMany " + "found err:" + errSelect.Error()) } } -func testUpdate(driver string, db *sql.DB, id int64) { - _, errUpdate := aorm.Db(db).Debug(false).Driver(driver).WhereEq(&person.Id, id).Update(&Person{Name: null.StringFrom("Bob")}) +func testUpdate(db *model.AormDB, id int64) { + _, errUpdate := aorm.Db(db).WhereEq(&person.Id, id).Update(&Person{Name: null.StringFrom("Bob")}) if errUpdate != nil { - panic(driver + "testUpdate" + "found err") + panic(db.DriverName() + "testUpdate" + "found err") } } -func testDelete(driver string, db *sql.DB, id int64) { - _, errDelete := aorm.Db(db).Driver(driver).Debug(false).Table(&person).WhereEq(&person.Id, id).Delete() +func testDelete(db *model.AormDB, id int64) { + _, errDelete := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Delete() if errDelete != nil { - panic(driver + "testDelete" + "found err") + panic(db.DriverName() + "testDelete" + "found err") } - _, errDelete2 := aorm.Db(db).Driver(driver).Debug(false).Delete(&Person{ + _, errDelete2 := aorm.Db(db).Delete(&Person{ Id: null.IntFrom(id), }) if errDelete2 != nil { - panic(driver + "testDelete" + "found err") + panic(db.DriverName() + "testDelete" + "found err") } } -func testExists(driver string, db *sql.DB, id int64) bool { - exists, err := aorm.Db(db).Driver(driver).Debug(false).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).Exists() +func testExists(db *model.AormDB, id int64) bool { + exists, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).Exists() if err != nil { - panic(driver + " testExists " + "found err:" + err.Error()) + panic(db.DriverName() + " testExists " + "found err:" + err.Error()) } return exists } -func testTable(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Driver(driver).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) +func testTable(db *model.AormDB) { + _, err := aorm.Db(db).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) if err != nil { - panic(driver + " testTable " + "found err:" + err.Error()) + panic(db.DriverName() + " testTable " + "found err:" + err.Error()) } - _, err2 := aorm.Db(db).Debug(false).Driver(driver).Table(&person).Insert(&Person{Name: null.StringFrom("Cherry")}) + _, err2 := aorm.Db(db).Table(&person).Insert(&Person{Name: null.StringFrom("Cherry")}) if err2 != nil { - panic(driver + " testTable " + "found err:" + err2.Error()) + panic(db.DriverName() + " testTable " + "found err:" + err2.Error()) } } -func testSelect(driver string, db *sql.DB) { +func testSelect(db *model.AormDB) { var listByFiled []Person - err := aorm.Db(db).Debug(false).Driver(driver).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) + err := aorm.Db(db).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) if err != nil { - panic(driver + " testSelect " + "found err:" + err.Error()) + panic(db.DriverName() + " testSelect " + "found err:" + err.Error()) } } -func testSelectWithSub(driver string, db *sql.DB) { +func testSelectWithSub(db *model.AormDB) { var listByFiled []PersonWithArticleCount - sub := aorm.Db().Table(&article).SelectCount(&article.Id, "article_count_tem").WhereRawEq(&article.PersonId, &person.Id) - err := aorm.Db(db).Debug(false). - Driver(driver). + sub := aorm.Db(db).Table(&article).SelectCount(&article.Id, "article_count_tem").WhereRawEq(&article.PersonId, &person.Id) + err := aorm.Db(db). SelectExp(&sub, &personWithArticleCount.ArticleCount). SelectAll(&person). Table(&person). @@ -390,26 +390,25 @@ func testSelectWithSub(driver string, db *sql.DB) { GetMany(&listByFiled) if err != nil { - panic(driver + " testSelectWithSub " + "found err:" + err.Error()) + panic(db.DriverName() + " testSelectWithSub " + "found err:" + err.Error()) } } -func testWhereWithSub(driver string, db *sql.DB) { +func testWhereWithSub(db *model.AormDB) { var listByFiled []Person - sub := aorm.Db().Table(&article).Driver(driver).SelectCount(&article.PersonId, "count_person_id").GroupBy(&article.PersonId).HavingGt("count_person_id", 0) - err := aorm.Db(db).Debug(false). + sub := aorm.Db(db).Table(&article).SelectCount(&article.PersonId, "count_person_id").GroupBy(&article.PersonId).HavingGt("count_person_id", 0) + err := aorm.Db(db). Table(&person). - Driver(driver). WhereIn(&person.Id, &sub). GetMany(&listByFiled) if err != nil { - panic(driver + " testWhereWithSub " + "found err:" + err.Error()) + panic(db.DriverName() + " testWhereWithSub " + "found err:" + err.Error()) } } -func testWhere(driver string, db *sql.DB) { +func testWhere(db *model.AormDB) { var listByWhere []Person - err := aorm.Db(db).Debug(false).Driver(driver).Table(&person).WhereArr([]builder.WhereItem{ + err := aorm.Db(db).Table(&person).WhereArr([]builder.WhereItem{ builder.GenWhereItem(&person.Type, builder.Eq, 0), builder.GenWhereItem(&person.Age, builder.In, []int{18, 20}), builder.GenWhereItem(&person.Money, builder.Between, []float64{100.1, 200.9}), @@ -417,13 +416,13 @@ func testWhere(driver string, db *sql.DB) { builder.GenWhereItem(&person.Name, builder.Like, []string{"%", "li", "%"}), }).GetMany(&listByWhere) if err != nil { - panic(driver + "testWhere" + "found err") + panic(db.DriverName() + "testWhere" + "found err") } } -func testJoin(driver string, db *sql.DB) { +func testJoin(db *model.AormDB) { var list2 []ArticleVO - err := aorm.Db(db).Debug(false).Driver(driver). + err := aorm.Db(db). Table(&article). LeftJoin( &person, @@ -437,13 +436,13 @@ func testJoin(driver string, db *sql.DB) { WhereIn(&person.Age, []int{18, 20}). GetMany(&list2) if err != nil { - panic(driver + " testWhere " + "found err " + err.Error()) + panic(db.DriverName() + " testWhere " + "found err " + err.Error()) } } -func testJoinWithAlias(driver string, db *sql.DB) { +func testJoinWithAlias(db *model.AormDB) { var list2 []ArticleVO - err := aorm.Db(db).Debug(false).Driver(driver). + err := aorm.Db(db). Table(&article, "o"). LeftJoin( &person, @@ -458,30 +457,29 @@ func testJoinWithAlias(driver string, db *sql.DB) { WhereIn(&person.Age, []int{18, 20}, "p"). GetMany(&list2) if err != nil { - panic(driver + " testWhere " + "found err " + err.Error()) + panic(db.DriverName() + " testWhere " + "found err " + err.Error()) } } -func testGroupBy(driver string, db *sql.DB) { +func testGroupBy(db *model.AormDB) { var personAgeItem PersonAge - err := aorm.Db(db).Debug(false). + err := aorm.Db(db). Table(&person). Select(&person.Age). SelectCount(&person.Age, &personAge.AgeCount). GroupBy(&person.Age). WhereEq(&person.Type, 0). - Driver(driver). OrderBy(&person.Age, builder.Desc). GetOne(&personAgeItem) if err != nil { - panic(driver + "testGroupBy" + "found err") + panic(db.DriverName() + "testGroupBy" + "found err") } } -func testHaving(driver string, db *sql.DB) { +func testHaving(db *model.AormDB) { var listByHaving []PersonAge - err := aorm.Db(db).Debug(false).Driver(driver). + err := aorm.Db(db). Table(&person). Select(&person.Age). SelectCount(&person.Age, &personAge.AgeCount). @@ -491,229 +489,225 @@ func testHaving(driver string, db *sql.DB) { HavingGt(&personAge.AgeCount, 4). GetMany(&listByHaving) if err != nil { - panic(driver + " testHaving " + "found err") + panic(db.DriverName() + " testHaving " + "found err") } } -func testOrderBy(driver string, db *sql.DB) { +func testOrderBy(db *model.AormDB) { var listByOrder []Person - err := aorm.Db(db).Debug(false).Driver(driver). + err := aorm.Db(db). Table(&person). WhereEq(&person.Type, 0). OrderBy(&person.Age, builder.Desc). GetMany(&listByOrder) if err != nil { - panic(driver + "testOrderBy" + "found err") + panic(db.DriverName() + "testOrderBy" + "found err") } var listByOrder2 []Person - err2 := aorm.Db(db).Debug(false).Driver(driver). + err2 := aorm.Db(db). Table(&person, "o"). WhereEq(&person.Type, 0, "o"). OrderBy(&person.Age, builder.Desc, "o"). GetMany(&listByOrder2) if err2 != nil { - panic(driver + "testOrderBy" + "found err") + panic(db.DriverName() + "testOrderBy" + "found err") } } -func testLimit(driver string, db *sql.DB) { +func testLimit(db *model.AormDB) { var list3 []Person - err1 := aorm.Db(db).Debug(false). + err1 := aorm.Db(db). Table(&person). WhereEq(&person.Type, 0). Limit(50, 10). - Driver(driver). OrderBy(&person.Id, builder.Desc). GetMany(&list3) if err1 != nil { - panic(driver + "testLimit" + "found err") + panic(db.DriverName() + "testLimit" + "found err") } var list4 []Person - err := aorm.Db(db).Debug(false). - Driver(driver). + err := aorm.Db(db). Table(&person). WhereEq(&person.Type, 0). Page(3, 10). OrderBy(&person.Id, builder.Desc). GetMany(&list4) if err != nil { - panic(driver + "testPage" + "found err") + panic(db.DriverName() + "testPage" + "found err") } } -func testLock(driver string, db *sql.DB, id int64) { - if driver == model.Sqlite3 || driver == model.Mssql { +func testLock(db *model.AormDB, id int64) { + if db.DriverName() == model.Sqlite3 || db.DriverName() == model.Mssql { return } var itemByLock Person err := aorm.Db(db). - Debug(false). LockForUpdate(true). Table(&person). WhereEq(&person.Id, id). - Driver(driver). OrderBy(&person.Id, builder.Desc). GetOne(&itemByLock) if err != nil { - panic(driver + "testLock" + "found err") + panic(db.DriverName() + "testLock" + "found err") } } -func testIncrement(driver string, db *sql.DB, id int64) { - _, err := aorm.Db(db).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).Increment(&person.Age, 1) +func testIncrement(db *model.AormDB, id int64) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Increment(&person.Age, 1) if err != nil { - panic(driver + " testIncrement " + "found err:" + err.Error()) + panic(db.DriverName() + " testIncrement " + "found err:" + err.Error()) } } -func testDecrement(driver string, db *sql.DB, id int64) { - _, err := aorm.Db(db).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).Decrement(&person.Age, 2) +func testDecrement(db *model.AormDB, id int64) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Decrement(&person.Age, 2) if err != nil { - panic(driver + "testDecrement" + "found err") + panic(db.DriverName() + "testDecrement" + "found err") } } -func testValue(driver string, db *sql.DB, id int64) { +func testValue(db *model.AormDB, id int64) { var name string - errName := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Name, &name) + errName := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Name, &name) if errName != nil { - panic(driver + "testValue" + "found err") + panic(db.DriverName() + "testValue" + "found err") } var age int64 - errAge := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Age, &age) + errAge := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Age, &age) if errAge != nil { - panic(driver + "testValue" + "found err") + panic(db.DriverName() + "testValue" + "found err") } var money float32 - errMoney := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Money, &money) + errMoney := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Money, &money) if errMoney != nil { - panic(driver + "testValue" + "found err") + panic(db.DriverName() + "testValue" + "found err") } var test float64 - errTest := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Test, &test) + errTest := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Test, &test) if errTest != nil { - panic(driver + "testValue" + "found err") + panic(db.DriverName() + "testValue" + "found err") } } -func testPluck(driver string, db *sql.DB) { +func testPluck(db *model.AormDB) { var nameList []string - errNameList := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Name, &nameList) + errNameList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Name, &nameList) if errNameList != nil { - panic(driver + "testPluck" + "found err") + panic(db.DriverName() + "testPluck" + "found err") } var ageList []int64 - errAgeList := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Age, &ageList) + errAgeList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Age, &ageList) if errAgeList != nil { - panic(driver + "testPluck" + "found err:" + errAgeList.Error()) + panic(db.DriverName() + "testPluck" + "found err:" + errAgeList.Error()) } var moneyList []float32 - errMoneyList := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Money, &moneyList) + errMoneyList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Money, &moneyList) if errMoneyList != nil { - panic(driver + "testPluck" + "found err") + panic(db.DriverName() + "testPluck" + "found err") } var testList []float64 - errTestList := aorm.Db(db).Debug(false).Driver(driver).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Test, &testList) + errTestList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Test, &testList) if errTestList != nil { - panic(driver + "testPluck" + "found err") + panic(db.DriverName() + "testPluck" + "found err") } } -func testCount(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Count("*") +func testCount(db *model.AormDB) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Count("*") if err != nil { - panic(driver + "testCount" + "found err") + panic(db.DriverName() + "testCount" + "found err") } } -func testSum(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Sum(&person.Age) +func testSum(db *model.AormDB) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Sum(&person.Age) if err != nil { - panic(driver + "testSum" + "found err") + panic(db.DriverName() + "testSum" + "found err") } } -func testAvg(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Avg(&person.Age) +func testAvg(db *model.AormDB) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Avg(&person.Age) if err != nil { - panic(driver + "testAvg" + "found err") + panic(db.DriverName() + "testAvg" + "found err") } } -func testMin(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Min(&person.Age) +func testMin(db *model.AormDB) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Min(&person.Age) if err != nil { - panic(driver + "testMin" + "found err") + panic(db.DriverName() + "testMin" + "found err") } } -func testMax(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Table(&person).WhereEq(&person.Age, 18).Driver(driver).Max(&person.Age) +func testMax(db *model.AormDB) { + _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Max(&person.Age) if err != nil { - panic(driver + "testMax" + "found err") + panic(db.DriverName() + "testMax" + "found err") } } -func testDistinct(driver string, db *sql.DB) { +func testDistinct(db *model.AormDB) { var listByFiled []Person - err := aorm.Db(db).Debug(false).Driver(driver).Distinct(true).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) + err := aorm.Db(db).Distinct(true).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) if err != nil { - panic(driver + " testSelect " + "found err:" + err.Error()) + panic(db.DriverName() + " testSelect " + "found err:" + err.Error()) } } -func testRawSql(driver string, db *sql.DB, id2 int64) { +func testRawSql(db *model.AormDB, id2 int64) { var list []Person - err1 := aorm.Db(db).Debug(false).Driver(driver).RawSql("SELECT * FROM person WHERE id=? AND type=?", id2, 0).GetMany(&list) + err1 := aorm.Db(db).RawSql("SELECT * FROM person WHERE id=? AND type=?", id2, 0).GetMany(&list) if err1 != nil { panic(err1) } - _, err := aorm.Db(db).Debug(false).Driver(driver).RawSql("UPDATE person SET name = ? WHERE id=?", "Bob2", id2).Exec() + _, err := aorm.Db(db).RawSql("UPDATE person SET name = ? WHERE id=?", "Bob2", id2).Exec() if err != nil { - panic(driver + "testRawSql" + "found err") + panic(db.DriverName() + "testRawSql" + "found err") } } -func testTransaction(driver string, db *sql.DB) { - tx, _ := db.Begin() +func testTransaction(db *model.AormDB) { + tx := db.Begin() - id, errInsert := aorm.Db(tx).Debug(false).Driver(driver).Insert(&Person{ + id, errInsert := aorm.Db(tx).Insert(&Person{ Name: null.StringFrom("Alice"), }) if errInsert != nil { tx.Rollback() - panic(driver + " testTransaction " + "found err:" + errInsert.Error()) + panic(db.DriverName() + " testTransaction " + "found err:" + errInsert.Error()) return } - _, errCount := aorm.Db(tx).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).Count("*") + _, errCount := aorm.Db(tx).Table(&person).WhereEq(&person.Id, id).Count("*") if errCount != nil { tx.Rollback() - panic(driver + "testTransaction" + "found err") + panic(db.DriverName() + "testTransaction" + "found err") return } var personItem Person - errPerson := aorm.Db(tx).Debug(false).Driver(driver).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) + errPerson := aorm.Db(tx).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) if errPerson != nil { tx.Rollback() - panic(driver + "testTransaction" + "found err") + panic(db.DriverName() + "testTransaction" + "found err") return } - _, errUpdate := aorm.Db(tx).Debug(false).Driver(driver).Where(&Person{ + _, errUpdate := aorm.Db(tx).Where(&Person{ Id: null.IntFrom(id), }).Update(&Person{ Name: null.StringFrom("Bob"), @@ -721,24 +715,24 @@ func testTransaction(driver string, db *sql.DB) { if errUpdate != nil { tx.Rollback() - panic(driver + "testTransaction" + "found err") + panic(db.DriverName() + "testTransaction" + "found err") return } tx.Commit() } -func testTruncate(driver string, db *sql.DB) { - _, err := aorm.Db(db).Debug(false).Driver(driver).Table(&person).Truncate() +func testTruncate(db *model.AormDB) { + _, err := aorm.Db(db).Table(&person).Truncate() if err != nil { - panic(driver + " testTruncate " + "found err") + panic(db.DriverName() + " testTruncate " + "found err") } } func testPreview() { //Content Mysql - db, _ := sql.Open("mysql", "root:root@tcp(localhost:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local") + db, _ := aorm.Open("mysql", "root:root@tcp(localhost:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local") //Insert a Person personId, _ := aorm.Db(db).Insert(&Person{ @@ -760,7 +754,7 @@ func testPreview() { //GetOne var personItem Person - err := aorm.Db(db).Table(&person).Table(&person).WhereEq(&person.Id, personId).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) + err := aorm.Db(db).Table(&person).WhereEq(&person.Id, personId).OrderBy(&person.Id, builder.Desc).GetOne(&personItem) if err != nil { fmt.Println(err.Error()) } @@ -791,17 +785,19 @@ func testPreview() { } func testDbContent() { - dbContent, _ := aorm.Open("mysql", "root:root@tcp(localhost:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local") - fmt.Println(dbContent) + db, err := aorm.Open("mysql", "root:root@tcp(localhost:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local") + if err != nil { + panic(err) + } - //aorm.Db(dbContent).Insert(&Person{ - // Name: null.StringFrom("test name"), - //}) - // - //tx := dbContent.Begin() - //aorm.Db(tx).Insert(&Person{ - // Name: null.StringFrom("test name"), - //}) - // - //tx.Commit() + aorm.Db(db).Insert(&Person{ + Name: null.StringFrom("test name"), + }) + + tx := db.Begin() + aorm.Db(tx).Insert(&Person{ + Name: null.StringFrom("test name"), + }) + + tx.Commit() }