diff --git a/builder/crud.go b/builder/crud.go index 8eb77e9..d9d1158 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -101,50 +101,50 @@ func (ex *Builder) Insert(dest interface{}) (int64, error) { sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")" - //如果是postgres,则转换?号到&1等 - if ex.driverName == "postgres" { - sqlStr = coverSql(sqlStr) + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) } - //如果是mssql - if ex.driverName == "mssql" { - rows, err := ex.LinkCommon.Query(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...) - if err != nil { - return 0, err - } - defer rows.Close() - var lastInsertId1 int64 - for rows.Next() { - rows.Scan(&lastInsertId1) - } - return lastInsertId1, nil - } else if ex.driverName == "postgres" { - rows, err := ex.LinkCommon.Query(sqlStr+" returning id", paramList...) - if err != nil { - return 0, err - } - defer rows.Close() - var lastInsertId1 int64 - for rows.Next() { - rows.Scan(&lastInsertId1) - } - return lastInsertId1, nil + if ex.driverName == model.Mssql { + return ex.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...) + } else if ex.driverName == model.Postgres { + return ex.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...) } else { - res, err := ex.Exec(sqlStr, paramList...) - if err != nil { - return 0, err - } - - lastId, err := res.LastInsertId() - if err != nil { - return 0, err - } - - return lastId, nil + return ex.insertForCommon(sqlStr, paramList...) } } -func coverSql(sqlStr string) string { +//对于Mssql,Postgres类型数据库,为了获取最后插入的id,需要改写入为查询 +func (ex *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) { + rows, err := ex.LinkCommon.Query(sql, paramList...) + if err != nil { + return 0, err + } + defer rows.Close() + var lastInsertId1 int64 + for rows.Next() { + rows.Scan(&lastInsertId1) + } + return lastInsertId1, nil +} + +//对于非Mssql,Postgres类型数据库,可以直接获取最后插入的id +func (ex *Builder) insertForCommon(sql string, paramList ...any) (int64, error) { + res, err := ex.Exec(sql, paramList...) + if err != nil { + return 0, err + } + + lastId, err := res.LastInsertId() + if err != nil { + return 0, err + } + + return lastId, nil +} + +//对于Postgres数据库,不支持?占位符,支持$1,$2类型,需要做转换 +func convertToPostgresSql(sqlStr string) string { t := 1 for { if strings.Index(sqlStr, "?") == -1 { @@ -197,8 +197,8 @@ func (ex *Builder) InsertBatch(values interface{}) (int64, error) { sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",") - if ex.driverName == "postgres" { - sqlStr = coverSql(sqlStr) + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) } res, err := ex.Exec(sqlStr, paramList...) @@ -325,9 +325,8 @@ func (ex *Builder) GetSqlAndParams() (string, []interface{}) { sqlStr := "SELECT " + fieldStr + " FROM " + ex.tableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr - //如果是postgres,则转换?号到&1等 - if ex.driverName == "postgres" { - sqlStr = coverSql(sqlStr) + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) } if ex.isDebug { @@ -345,9 +344,8 @@ func (ex *Builder) Update(dest interface{}) (int64, error) { whereStr, paramList := ex.handleWhere(ex.whereList, paramList) sqlStr := "UPDATE " + ex.tableName + setStr + whereStr - //如果是postgres,则转换?号到&1等 - if ex.driverName == "postgres" { - sqlStr = coverSql(sqlStr) + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) } return ex.ExecAffected(sqlStr, paramList...) @@ -359,9 +357,8 @@ func (ex *Builder) Delete() (int64, error) { whereStr, paramList := ex.handleWhere(ex.whereList, paramList) sqlStr := "DELETE FROM " + ex.tableName + whereStr - //如果是postgres,则转换?号到&1等 - if ex.driverName == "postgres" { - sqlStr = coverSql(sqlStr) + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) } return ex.ExecAffected(sqlStr, paramList...) @@ -370,7 +367,7 @@ func (ex *Builder) Delete() (int64, error) { // Truncate 清空记录, sqlte3不支持此操作 func (ex *Builder) Truncate() (int64, error) { sqlStr := "TRUNCATE TABLE " + ex.tableName - if ex.driverName == "sqlite3" { + if ex.driverName == model.Sqlite3 { sqlStr = "DELETE FROM " + ex.tableName } @@ -464,6 +461,10 @@ func (ex *Builder) Increment(fieldName string, step int) (int64, error) { whereStr, paramList := ex.handleWhere(ex.whereList, paramList) sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "+?" + whereStr + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) + } + return ex.ExecAffected(sqlStr, paramList...) } @@ -474,11 +475,19 @@ func (ex *Builder) Decrement(fieldName string, step int) (int64, error) { whereStr, paramList := ex.handleWhere(ex.whereList, paramList) sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "-?" + whereStr + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) + } + return ex.ExecAffected(sqlStr, paramList...) } // Exec 通用执行-新增,更新,删除 func (ex *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { + if ex.driverName == model.Postgres { + sqlStr = convertToPostgresSql(sqlStr) + } + if ex.isDebug { fmt.Println(sqlStr) fmt.Println(args...) @@ -574,14 +583,14 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, } } else { if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le { - if ex.driverName == "sqlite3" { + if ex.driverName == model.Sqlite3 { whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?") } else { switch where[i].Val.(type) { case float32: - whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?") + whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?") case float64: - whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?") + whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?") default: whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?") } @@ -610,7 +619,7 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, } } - whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcat(valueStr...)) + whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcatForLike(valueStr...)) } if where[i].Opt == In || where[i].Opt == NotIn { @@ -701,8 +710,18 @@ func getScans(columnNameList []string, fieldNameMap map[string]int, destValue re return scans } -func (ex *Builder) getConcat(vars ...string) string { - if ex.driverName == "sqlite3" { +func (ex *Builder) getConcatForFloat(vars ...string) string { + if ex.driverName == model.Sqlite3 { + return strings.Join(vars, "||") + } else if ex.driverName == model.Postgres { + return vars[0] + } else { + return "CONCAT(" + strings.Join(vars, ",") + ")" + } +} + +func (ex *Builder) getConcatForLike(vars ...string) string { + if ex.driverName == model.Sqlite3 || ex.driverName == model.Postgres { return strings.Join(vars, "||") } else { return "CONCAT(" + strings.Join(vars, ",") + ")" diff --git a/builder/handle.go b/builder/handle.go index 0058b76..6d9ff38 100644 --- a/builder/handle.go +++ b/builder/handle.go @@ -2,6 +2,7 @@ package builder import ( "github.com/tangpanqing/aorm/helper" + "github.com/tangpanqing/aorm/model" "reflect" "strings" ) @@ -97,14 +98,14 @@ func handleOrder(orderList []string) string { return " Order BY " + strings.Join(orderList, ",") } -//拼接SQL,分页相关 +//拼接SQL,分页相关 Postgres数据库分页数量在前偏移在后,其他数据库偏移量在前分页数量在后,另外Mssql数据库的关键词是offset...next func (ex *Builder) handleLimit(offset int, pageSize int, paramList []any) (string, []any) { if 0 == pageSize { return "", paramList } str := "" - if ex.driverName == "postgres" { + if ex.driverName == model.Postgres { paramList = append(paramList, pageSize) paramList = append(paramList, offset) @@ -114,7 +115,7 @@ func (ex *Builder) handleLimit(offset int, pageSize int, paramList []any) (strin paramList = append(paramList, pageSize) str = " Limit ?,? " - if ex.driverName == "mssql" { + if ex.driverName == model.Mssql { str = " offset ? rows fetch next ? rows only " } } diff --git a/migrator/migrator.go b/migrator/migrator.go index ac5c778..9729c2c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -40,7 +40,7 @@ func (mi *Migrator) Opinion(key string, val string) *Migrator { //ShowCreateTable 获取创建表的ddl func (mi *Migrator) ShowCreateTable(tableName string) string { - if mi.driverName == "mysql" { + if mi.driverName == model.Mysql { me := migrate_mysql.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, @@ -69,7 +69,7 @@ func (mi *Migrator) Migrate(tableName string, dest interface{}) { } func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { - if mi.driverName == "mssql" { + if mi.driverName == model.Mssql { me := migrate_mssql.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, @@ -80,7 +80,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { me.MigrateCommon(tableName, typeOf) } - if mi.driverName == "mysql" { + if mi.driverName == model.Mysql { me := migrate_mysql.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, @@ -91,7 +91,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { me.MigrateCommon(tableName, typeOf) } - if mi.driverName == "sqlite3" { + if mi.driverName == model.Sqlite3 { me := migrate_sqlite3.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, @@ -102,7 +102,7 @@ func (mi *Migrator) migrateCommon(tableName string, typeOf reflect.Type) { me.MigrateCommon(tableName, typeOf) } - if mi.driverName == "postgres" { + if mi.driverName == model.Postgres { me := migrate_postgres.MigrateExecutor{ DriverName: mi.driverName, OpinionList: mi.opinionList, diff --git a/model/model.go b/model/model.go index 75d2b71..70b0bcf 100644 --- a/model/model.go +++ b/model/model.go @@ -13,3 +13,8 @@ type OpinionItem struct { Key string Val string } + +const Mysql = "mysql" +const Mssql = "mssql" +const Postgres = "postgres" +const Sqlite3 = "sqlite3" diff --git a/test/aorm_test.go b/test/aorm_test.go index 632fcd7..92ad06c 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -10,6 +10,7 @@ import ( "github.com/tangpanqing/aorm" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/helper" + "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" "testing" ) @@ -59,10 +60,10 @@ type PersonWithArticleCount struct { func TestAll(t *testing.T) { dbList := make([]aorm.DbContent, 0) - //dbList = append(dbList, testSqlite3Connect()) - //dbList = append(dbList, testMysqlConnect()) + dbList = append(dbList, testSqlite3Connect()) + dbList = append(dbList, testMysqlConnect()) dbList = append(dbList, testPostgresConnect()) - //dbList = append(dbList, testMssqlConnect()) + dbList = append(dbList, testMssqlConnect()) for i := 0; i < len(dbList); i++ { dbItem := dbList[i] @@ -108,7 +109,7 @@ func TestAll(t *testing.T) { testExec(dbItem.DriverName, dbItem.DbLink) testTransaction(dbItem.DriverName, dbItem.DbLink) - //testTruncate(dbItem.DriverName, dbItem.DbLink) + testTruncate(dbItem.DriverName, dbItem.DbLink) testHelper(dbItem.DriverName, dbItem.DbLink) } } @@ -133,6 +134,11 @@ func testMysqlConnect() aorm.DbContent { panic(mysqlErr) } + err := mysqlContent.DbLink.Ping() + if err != nil { + panic(err) + } + return mysqlContent } @@ -153,7 +159,7 @@ func testPostgresConnect() aorm.DbContent { } func testMssqlConnect() aorm.DbContent { - info := fmt.Sprintf("server=%s;database=%s;user id=%s;password=%s;port=%d", "localhost", "database_name", "sa", "root", 1433) + info := fmt.Sprintf("server=%s;database=%s;user id=%s;password=%s;port=%d;encrypt=disable", "localhost", "database_name", "sa", "root", 1433) mssqlContent, mssqlErr := aorm.Open("mssql", info) if mssqlErr != nil { panic(mssqlErr) @@ -167,20 +173,20 @@ func testMssqlConnect() aorm.DbContent { return mssqlContent } -func testMigrate(name string, db *sql.DB) { +func testMigrate(driver string, db *sql.DB) { //AutoMigrate - aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{}) - aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{}) + aorm.Migrator(db).Driver(driver).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").AutoMigrate(&Person{}) + aorm.Migrator(db).Driver(driver).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "文章").AutoMigrate(&Article{}) //Migrate - aorm.Migrator(db).Driver(name).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").Migrate("person_1", &Person{}) + aorm.Migrator(db).Driver(driver).Opinion("ENGINE", "InnoDB").Opinion("COMMENT", "人员表").Migrate("person_1", &Person{}) } -func testShowCreateTable(name string, db *sql.DB) { - aorm.Migrator(db).Driver(name).ShowCreateTable("person") +func testShowCreateTable(driver string, db *sql.DB) { + aorm.Migrator(db).Driver(driver).ShowCreateTable("person") } -func testInsert(name string, db *sql.DB) int64 { +func testInsert(driver string, db *sql.DB) int64 { obj := Person{ Name: null.StringFrom("Alice"), Sex: null.BoolFrom(false), @@ -191,50 +197,50 @@ func testInsert(name string, db *sql.DB) int64 { Test: null.FloatFrom(2), } - id, errInsert := aorm.Use(db).Debug(true).Driver(name).Insert(&obj) + id, errInsert := aorm.Use(db).Debug(false).Driver(driver).Insert(&obj) if errInsert != nil { - panic(name + " testInsert " + "found err: " + errInsert.Error()) + panic(driver + " testInsert " + "found err: " + errInsert.Error()) } - aorm.Use(db).Debug(false).Driver(name).Insert(&Article{ + aorm.Use(db).Debug(false).Driver(driver).Insert(&Article{ Type: null.IntFrom(0), PersonId: null.IntFrom(id), ArticleBody: null.StringFrom("文章内容"), }) var person Person - err := aorm.Use(db).Table("person").Debug(true).Driver(name).WhereEq("id", id).OrderBy("id", "DESC").GetOne(&person) + err := aorm.Use(db).Table("person").Debug(false).Driver(driver).WhereEq("id", id).OrderBy("id", "DESC").GetOne(&person) if err != nil { fmt.Println(err.Error()) } if obj.Name.String != person.Name.String { - fmt.Println("Name not match, expected: " + obj.Name.String + " ,but real is : " + person.Name.String) + fmt.Println(driver + ",Name not match, expected: " + obj.Name.String + " ,but real is : " + person.Name.String) } if obj.Sex.Bool != person.Sex.Bool { - fmt.Println("Sex not match, expected: " + fmt.Sprintf("%v", obj.Sex.Bool) + " ,but real is : " + fmt.Sprintf("%v", person.Sex.Bool)) + fmt.Println(driver + ",Sex not match, expected: " + fmt.Sprintf("%v", obj.Sex.Bool) + " ,but real is : " + fmt.Sprintf("%v", person.Sex.Bool)) } if obj.Age.Int64 != person.Age.Int64 { - fmt.Println("Age not match, expected: " + fmt.Sprintf("%v", obj.Age.Int64) + " ,but real is : " + fmt.Sprintf("%v", person.Age.Int64)) + fmt.Println(driver + ",Age not match, expected: " + fmt.Sprintf("%v", obj.Age.Int64) + " ,but real is : " + fmt.Sprintf("%v", person.Age.Int64)) } if obj.Type.Int64 != person.Type.Int64 { - fmt.Println("Type not match, expected: " + fmt.Sprintf("%v", obj.Type.Int64) + " ,but real is : " + fmt.Sprintf("%v", person.Type.Int64)) + fmt.Println(driver + ",Type not match, expected: " + fmt.Sprintf("%v", obj.Type.Int64) + " ,but real is : " + fmt.Sprintf("%v", person.Type.Int64)) } if obj.Money.Float64 != person.Money.Float64 { - fmt.Println(name + ",Money not match, expected: " + fmt.Sprintf("%v", obj.Money.Float64) + " ,but real is : " + fmt.Sprintf("%v", person.Money.Float64)) + fmt.Println(driver + ",Money not match, expected: " + fmt.Sprintf("%v", obj.Money.Float64) + " ,but real is : " + fmt.Sprintf("%v", person.Money.Float64)) } if obj.Test.Float64 != person.Test.Float64 { - fmt.Println(name + ",Test not match, expected: " + fmt.Sprintf("%v", obj.Test.Float64) + " ,but real is : " + fmt.Sprintf("%v", person.Test.Float64)) + fmt.Println(driver + ",Test not match, expected: " + fmt.Sprintf("%v", obj.Test.Float64) + " ,but real is : " + fmt.Sprintf("%v", person.Test.Float64)) } return id } -func testInsertBatch(name string, db *sql.DB) int64 { +func testInsertBatch(driver string, db *sql.DB) int64 { var batch []Person batch = append(batch, Person{ Name: null.StringFrom("Alice"), @@ -256,92 +262,92 @@ func testInsertBatch(name string, db *sql.DB) int64 { Test: null.FloatFrom(200.15987654321987654321), }) - count, err := aorm.Use(db).Debug(true).Driver(name).InsertBatch(&batch) + count, err := aorm.Use(db).Debug(false).Driver(driver).InsertBatch(&batch) if err != nil { - panic(name + " testInsertBatch " + "found err:" + err.Error()) + panic(driver + " testInsertBatch " + "found err:" + err.Error()) } return count } -func testGetOne(name string, db *sql.DB, id int64) { +func testGetOne(driver string, db *sql.DB, id int64) { var person Person - errFind := aorm.Use(db).Debug(false).Driver(name).OrderBy("id", "DESC").Where(&Person{Id: null.IntFrom(id)}).GetOne(&person) + errFind := aorm.Use(db).Debug(false).Driver(driver).OrderBy("id", "DESC").Where(&Person{Id: null.IntFrom(id)}).GetOne(&person) if errFind != nil { - panic(name + "testGetOne" + "found err") + panic(driver + "testGetOne" + "found err") } } -func testGetMany(name string, db *sql.DB) { +func testGetMany(driver string, db *sql.DB) { var list []Person - errSelect := aorm.Use(db).Driver(name).Debug(false).Where(&Person{Type: null.IntFrom(0)}).GetMany(&list) + errSelect := aorm.Use(db).Driver(driver).Debug(false).Where(&Person{Type: null.IntFrom(0)}).GetMany(&list) if errSelect != nil { - panic(name + " testGetMany " + "found err:" + errSelect.Error()) + panic(driver + " testGetMany " + "found err:" + errSelect.Error()) } } -func testUpdate(name string, db *sql.DB, id int64) { - _, errUpdate := aorm.Use(db).Debug(false).Driver(name).Where(&Person{Id: null.IntFrom(id)}).Update(&Person{Name: null.StringFrom("Bob")}) +func testUpdate(driver string, db *sql.DB, id int64) { + _, errUpdate := aorm.Use(db).Debug(false).Driver(driver).Where(&Person{Id: null.IntFrom(id)}).Update(&Person{Name: null.StringFrom("Bob")}) if errUpdate != nil { - panic(name + "testGetMany" + "found err") + panic(driver + "testGetMany" + "found err") } } -func testDelete(name string, db *sql.DB, id int64) { - _, errDelete := aorm.Use(db).Driver(name).Debug(false).Where(&Person{Id: null.IntFrom(id)}).Delete() +func testDelete(driver string, db *sql.DB, id int64) { + _, errDelete := aorm.Use(db).Driver(driver).Debug(false).Where(&Person{Id: null.IntFrom(id)}).Delete() if errDelete != nil { - panic(name + "testDelete" + "found err") + panic(driver + "testDelete" + "found err") } } -func testTable(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Driver(name).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) +func testTable(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Driver(driver).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) if err != nil { - panic(name + " testTable " + "found err:" + err.Error()) + panic(driver + " testTable " + "found err:" + err.Error()) } } -func testSelect(name string, db *sql.DB) { +func testSelect(driver string, db *sql.DB) { var listByFiled []Person - err := aorm.Use(db).Debug(false).Driver(name).Select("name,age").Where(&Person{Age: null.IntFrom(18)}).GetMany(&listByFiled) + err := aorm.Use(db).Debug(false).Driver(driver).Select("name,age").Where(&Person{Age: null.IntFrom(18)}).GetMany(&listByFiled) if err != nil { - panic(name + " testSelect " + "found err:" + err.Error()) + panic(driver + " testSelect " + "found err:" + err.Error()) } } -func testSelectWithSub(name string, db *sql.DB) { +func testSelectWithSub(driver string, db *sql.DB) { var listByFiled []PersonWithArticleCount sub := aorm.Sub().Table("article").SelectCount("id", "article_count_tem").WhereRaw("person_id", "=person.id") err := aorm.Use(db).Debug(false). - Driver(name). + Driver(driver). SelectExp(&sub, "article_count"). Select("*"). Where(&Person{Age: null.IntFrom(18)}). GetMany(&listByFiled) if err != nil { - panic(name + " testSelectWithSub " + "found err:" + err.Error()) + panic(driver + " testSelectWithSub " + "found err:" + err.Error()) } } -func testWhereWithSub(name string, db *sql.DB) { +func testWhereWithSub(driver string, db *sql.DB) { var listByFiled []Person sub := aorm.Sub().Table("article").Select("person_id").GroupBy("person_id").HavingGt("count(person_id)", 0) err := aorm.Use(db).Debug(false). Table("person"). - Driver(name). + Driver(driver). WhereIn("id", &sub). GetMany(&listByFiled) if err != nil { - panic(name + " testWhereWithSub " + "found err:" + err.Error()) + panic(driver + " testWhereWithSub " + "found err:" + err.Error()) } } -func testWhere(name string, db *sql.DB) { +func testWhere(driver string, db *sql.DB) { var listByWhere []Person var where1 []builder.WhereItem @@ -351,13 +357,13 @@ func testWhere(name string, db *sql.DB) { where1 = append(where1, builder.WhereItem{Field: "money", Opt: builder.Eq, Val: 100.15}) where1 = append(where1, builder.WhereItem{Field: "name", Opt: builder.Like, Val: []string{"%", "li", "%"}}) - err := aorm.Use(db).Debug(true).Driver(name).Table("person").WhereArr(where1).GetMany(&listByWhere) + err := aorm.Use(db).Debug(false).Driver(driver).Table("person").WhereArr(where1).GetMany(&listByWhere) if err != nil { - panic(name + "testWhere" + "found err") + panic(driver + "testWhere" + "found err") } } -func testJoin(name string, db *sql.DB) { +func testJoin(driver string, db *sql.DB) { var list2 []ArticleVO var where2 []builder.WhereItem where2 = append(where2, builder.WhereItem{Field: "o.type", Opt: builder.Eq, Val: 0}) @@ -368,13 +374,14 @@ func testJoin(name string, db *sql.DB) { Select("o.*"). Select("p.name as person_name"). WhereArr(where2). + Driver(driver). GetMany(&list2) if err != nil { - panic(name + " testWhere " + "found err " + err.Error()) + panic(driver + " testWhere " + "found err " + err.Error()) } } -func testGroupBy(name string, db *sql.DB) { +func testGroupBy(driver string, db *sql.DB) { var personAge PersonAge var where []builder.WhereItem where = append(where, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) @@ -384,15 +391,15 @@ func testGroupBy(name string, db *sql.DB) { Select("count(age) as age_count"). GroupBy("age"). WhereArr(where). - Driver(name). + Driver(driver). OrderBy("age", "DESC"). GetOne(&personAge) if err != nil { - panic(name + "testGroupBy" + "found err") + panic(driver + "testGroupBy" + "found err") } } -func testHaving(name string, db *sql.DB) { +func testHaving(driver string, db *sql.DB) { var listByHaving []PersonAge var where3 []builder.WhereItem @@ -407,16 +414,16 @@ func testHaving(name string, db *sql.DB) { Select("count(age) as age_count"). GroupBy("age"). WhereArr(where3). - Driver(name). + Driver(driver). OrderBy("age", "DESC"). HavingArr(having). GetMany(&listByHaving) if err != nil { - panic(name + " testHaving " + "found err") + panic(driver + " testHaving " + "found err") } } -func testOrderBy(name string, db *sql.DB) { +func testOrderBy(driver string, db *sql.DB) { var listByOrder []Person var where []builder.WhereItem where = append(where, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) @@ -424,13 +431,14 @@ func testOrderBy(name string, db *sql.DB) { Table("person"). WhereArr(where). OrderBy("age", builder.Desc). + Driver(driver). GetMany(&listByOrder) if err != nil { - panic(name + "testOrderBy" + "found err") + panic(driver + "testOrderBy" + "found err") } } -func testLimit(name string, db *sql.DB) { +func testLimit(driver string, db *sql.DB) { var list3 []Person var where1 []builder.WhereItem where1 = append(where1, builder.WhereItem{Field: "type", Opt: builder.Eq, Val: 0}) @@ -438,11 +446,11 @@ func testLimit(name string, db *sql.DB) { Table("person"). WhereArr(where1). Limit(50, 10). - Driver(name). + Driver(driver). OrderBy("id", "DESC"). GetMany(&list3) if err1 != nil { - panic(name + "testLimit" + "found err") + panic(driver + "testLimit" + "found err") } var list4 []Person @@ -452,16 +460,16 @@ func testLimit(name string, db *sql.DB) { Table("person"). WhereArr(where2). Page(3, 10). - Driver(name). + Driver(driver). OrderBy("id", "DESC"). GetMany(&list4) if err != nil { - panic(name + "testPage" + "found err") + panic(driver + "testPage" + "found err") } } -func testLock(name string, db *sql.DB, id int64) { - if name == "sqlite3" || name == "mssql" { +func testLock(driver string, db *sql.DB, id int64) { + if driver == model.Sqlite3 || driver == model.Mssql { return } @@ -470,25 +478,25 @@ func testLock(name string, db *sql.DB, id int64) { Debug(false). LockForUpdate(true). Where(&Person{Id: null.IntFrom(id)}). - Driver(name). + Driver(driver). OrderBy("id", "DESC"). GetOne(&itemByLock) if err != nil { - panic(name + "testLock" + "found err") + panic(driver + "testLock" + "found err") } } -func testIncrement(name string, db *sql.DB, id int64) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Id: null.IntFrom(id)}).Increment("age", 1) +func testIncrement(driver string, db *sql.DB, id int64) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Id: null.IntFrom(id)}).Driver(driver).Increment("age", 1) if err != nil { - panic(name + "testIncrement" + "found err") + panic(driver + " testIncrement " + "found err:" + err.Error()) } } -func testDecrement(name string, db *sql.DB, id int64) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Id: null.IntFrom(id)}).Decrement("age", 2) +func testDecrement(driver string, db *sql.DB, id int64) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Id: null.IntFrom(id)}).Driver(driver).Decrement("age", 2) if err != nil { - panic(name + "testDecrement" + "found err") + panic(driver + "testDecrement" + "found err") } } @@ -519,108 +527,108 @@ func testValue(driver string, db *sql.DB, id int64) { } } -func testPluck(name string, db *sql.DB) { +func testPluck(driver string, db *sql.DB) { var nameList []string - errNameList := aorm.Use(db).Debug(false).Driver(name).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("name", &nameList) + errNameList := aorm.Use(db).Debug(false).Driver(driver).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("name", &nameList) if errNameList != nil { - panic(name + "testPluck" + "found err") + panic(driver + "testPluck" + "found err") } var ageList []int64 - errAgeList := aorm.Use(db).Debug(false).Driver(name).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("age", &ageList) + errAgeList := aorm.Use(db).Debug(false).Driver(driver).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("age", &ageList) if errAgeList != nil { - panic(name + "testPluck" + "found err:" + errAgeList.Error()) + panic(driver + "testPluck" + "found err:" + errAgeList.Error()) } var moneyList []float32 - errMoneyList := aorm.Use(db).Debug(false).Driver(name).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("money", &moneyList) + errMoneyList := aorm.Use(db).Debug(false).Driver(driver).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("money", &moneyList) if errMoneyList != nil { - panic(name + "testPluck" + "found err") + panic(driver + "testPluck" + "found err") } var testList []float64 - errTestList := aorm.Use(db).Debug(false).Driver(name).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("test", &testList) + errTestList := aorm.Use(db).Debug(false).Driver(driver).OrderBy("id", "DESC").Where(&Person{Type: null.IntFrom(0)}).Limit(0, 3).Pluck("test", &testList) if errTestList != nil { - panic(name + "testPluck" + "found err") + panic(driver + "testPluck" + "found err") } } -func testCount(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Count("*") +func testCount(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Driver(driver).Count("*") if err != nil { - panic(name + "testCount" + "found err") + panic(driver + "testCount" + "found err") } } -func testSum(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Sum("age") +func testSum(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Driver(driver).Sum("age") if err != nil { - panic(name + "testSum" + "found err") + panic(driver + "testSum" + "found err") } } -func testAvg(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Avg("age") +func testAvg(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Driver(driver).Avg("age") if err != nil { - panic(name + "testAvg" + "found err") + panic(driver + "testAvg" + "found err") } } -func testMin(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Min("age") +func testMin(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Driver(driver).Min("age") if err != nil { - panic(name + "testMin" + "found err") + panic(driver + "testMin" + "found err") } } -func testMax(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Max("age") +func testMax(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Where(&Person{Age: null.IntFrom(18)}).Driver(driver).Max("age") if err != nil { - panic(name + "testMax" + "found err") + panic(driver + "testMax" + "found err") } } -func testExec(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3) +func testExec(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Driver(driver).Exec("UPDATE person SET name = ? WHERE id=?", "Bob", 3) if err != nil { - panic(name + "testExec" + "found err") + panic(driver + "testExec" + "found err") } } -func testTransaction(name string, db *sql.DB) { +func testTransaction(driver string, db *sql.DB) { tx, _ := db.Begin() - id, errInsert := aorm.Use(tx).Debug(false).Driver(name).Insert(&Person{ + id, errInsert := aorm.Use(tx).Debug(false).Driver(driver).Insert(&Person{ Name: null.StringFrom("Alice"), }) if errInsert != nil { tx.Rollback() - panic(name + " testTransaction " + "found err:" + errInsert.Error()) + panic(driver + " testTransaction " + "found err:" + errInsert.Error()) return } - _, errCount := aorm.Use(tx).Debug(false).Where(&Person{ + _, errCount := aorm.Use(tx).Debug(false).Driver(driver).Where(&Person{ Id: null.IntFrom(id), }).Count("*") if errCount != nil { tx.Rollback() - panic(name + "testTransaction" + "found err") + panic(driver + "testTransaction" + "found err") return } var person Person errPerson := aorm.Use(tx).Debug(false).Where(&Person{ Id: null.IntFrom(id), - }).Driver(name).OrderBy("id", "DESC").GetOne(&person) + }).Driver(driver).OrderBy("id", "DESC").GetOne(&person) if errPerson != nil { tx.Rollback() - panic(name + "testTransaction" + "found err") + panic(driver + "testTransaction" + "found err") return } - _, errUpdate := aorm.Use(tx).Debug(false).Where(&Person{ + _, errUpdate := aorm.Use(tx).Debug(false).Driver(driver).Where(&Person{ Id: null.IntFrom(id), }).Update(&Person{ Name: null.StringFrom("Bob"), @@ -628,21 +636,21 @@ func testTransaction(name string, db *sql.DB) { if errUpdate != nil { tx.Rollback() - panic(name + "testTransaction" + "found err") + panic(driver + "testTransaction" + "found err") return } tx.Commit() } -func testTruncate(name string, db *sql.DB) { - _, err := aorm.Use(db).Debug(false).Driver(name).Table("person").Truncate() +func testTruncate(driver string, db *sql.DB) { + _, err := aorm.Use(db).Debug(false).Driver(driver).Table("person").Truncate() if err != nil { - panic(name + " testTruncate " + "found err") + panic(driver + " testTruncate " + "found err") } } -func testHelper(name string, db *sql.DB) { +func testHelper(driver string, db *sql.DB) { var list2 []ArticleVO var where2 []builder.WhereItem where2 = append(where2, builder.WhereItem{Field: "o.type", Opt: builder.Eq, Val: 0}) @@ -653,8 +661,9 @@ func testHelper(name string, db *sql.DB) { Select("o.*"). Select(helper.Ul("p.name as personName")). WhereArr(where2). + Driver(driver). GetMany(&list2) if err != nil { - panic(name + "testHelper" + "found err") + panic(driver + "testHelper" + "found err") } }