diff --git a/aorm.go b/aorm.go index 6988674..7dce987 100644 --- a/aorm.go +++ b/aorm.go @@ -2,24 +2,24 @@ package aorm import ( "database/sql" //只需导入你需要的驱动即可 + "github.com/tangpanqing/aorm/base" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/migrator" - "github.com/tangpanqing/aorm/model" ) //Open 开始一个数据库连接 -func Open(driverName string, dataSourceName string) (*model.AormDB, error) { +func Open(driverName string, dataSourceName string) (*base.Db, error) { sqlDB, err := sql.Open(driverName, dataSourceName) if err != nil { - return &model.AormDB{}, err + return &base.Db{}, err } err2 := sqlDB.Ping() if err2 != nil { - return &model.AormDB{}, err2 + return &base.Db{}, err2 } - return &model.AormDB{ + return &base.Db{ Driver: driverName, SqlDB: sqlDB, }, nil @@ -30,7 +30,7 @@ func Store(destList ...interface{}) { } // Db 开始一个数据库操作 -func Db(link model.AormLink) *builder.Builder { +func Db(link base.Link) *builder.Builder { b := &builder.Builder{} b.Link = link @@ -40,7 +40,7 @@ func Db(link model.AormLink) *builder.Builder { } // Migrator 开始一个数据库迁移 -func Migrator(linkCommon model.AormLink) *migrator.Migrator { +func Migrator(linkCommon base.Link) *migrator.Migrator { mi := &migrator.Migrator{ Link: linkCommon, } diff --git a/base/Db.go b/base/Db.go new file mode 100644 index 0000000..b802efc --- /dev/null +++ b/base/Db.go @@ -0,0 +1,79 @@ +package base + +import ( + "database/sql" + "time" +) + +type Db struct { + Driver string + DebugMode bool + SqlDB *sql.DB +} + +//Close 关闭 +func (db *Db) Close() error { + return db.SqlDB.Close() +} + +//Begin 开始一个事务 +func (db *Db) Begin() *Tx { + SqlTx, _ := db.SqlDB.Begin() + + return &Tx{ + driver: db.Driver, + debugMode: db.DebugMode, + + sqlTx: SqlTx, + } +} + +//SetDebugMode 获取调试模式 +func (db *Db) SetDebugMode(debugMode bool) { + db.DebugMode = debugMode +} + +func (db *Db) SetConnMaxLifetime(d time.Duration) { + db.SqlDB.SetConnMaxLifetime(d) +} + +func (db *Db) SetConnMaxIdleTime(d time.Duration) { + db.SqlDB.SetConnMaxIdleTime(d) +} + +func (db *Db) SetMaxIdleConns(n int) { + db.SqlDB.SetMaxIdleConns(n) +} + +func (db *Db) SetMaxOpenConns(n int) { + db.SqlDB.SetMaxOpenConns(n) +} + +func (db *Db) Stats() sql.DBStats { + return db.SqlDB.Stats() +} + +//GetDebugMode 获取调试模式 +func (db *Db) GetDebugMode() bool { + return db.DebugMode +} + +func (db *Db) DriverName() string { + return db.Driver +} + +func (db *Db) Exec(query string, args ...interface{}) (sql.Result, error) { + return db.SqlDB.Exec(query, args...) +} + +func (db *Db) Prepare(query string) (*sql.Stmt, error) { + return db.SqlDB.Prepare(query) +} + +func (db *Db) Query(query string, args ...interface{}) (*sql.Rows, error) { + return db.SqlDB.Query(query, args...) +} + +func (db *Db) QueryRow(query string, args ...interface{}) *sql.Row { + return db.SqlDB.QueryRow(query, args...) +} diff --git a/model/AormLink.go b/base/Link.go similarity index 87% rename from model/AormLink.go rename to base/Link.go index 5729651..de1ba12 100644 --- a/model/AormLink.go +++ b/base/Link.go @@ -1,8 +1,8 @@ -package model +package base import "database/sql" -type AormLink interface { +type Link interface { GetDebugMode() bool DriverName() string Exec(query string, args ...interface{}) (sql.Result, error) diff --git a/base/Tx.go b/base/Tx.go new file mode 100644 index 0000000..f90274e --- /dev/null +++ b/base/Tx.go @@ -0,0 +1,42 @@ +package base + +import "database/sql" + +type Tx struct { + driver string + debugMode bool + sqlTx *sql.Tx +} + +//GetDebugMode 获取调试状态 +func (tx *Tx) GetDebugMode() bool { + return tx.debugMode +} + +func (tx *Tx) DriverName() string { + return tx.driver +} + +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.sqlTx.Exec(query, args...) +} + +func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { + return tx.sqlTx.Prepare(query) +} + +func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return tx.sqlTx.Query(query, args...) +} + +func (tx *Tx) QueryRow(query string, args ...interface{}) *sql.Row { + return tx.sqlTx.QueryRow(query, args...) +} + +func (tx *Tx) Rollback() error { + return tx.sqlTx.Rollback() +} + +func (tx *Tx) Commit() error { + return tx.sqlTx.Commit() +} diff --git a/builder/crud.go b/builder/crud.go index 814b5dd..f667324 100644 --- a/builder/crud.go +++ b/builder/crud.go @@ -4,8 +4,8 @@ import ( "database/sql" "errors" "fmt" + "github.com/tangpanqing/aorm/base" "github.com/tangpanqing/aorm/driver" - "github.com/tangpanqing/aorm/model" "reflect" "strconv" "strings" @@ -34,7 +34,7 @@ const RawEq = "RawEq" // Builder 查询记录所需要的条件 type Builder struct { - Link model.AormLink + Link base.Link table interface{} tableAlias string diff --git a/migrate_postgres/migrate.go b/migrate_postgres/migrate.go index e322543..33964a2 100644 --- a/migrate_postgres/migrate.go +++ b/migrate_postgres/migrate.go @@ -256,7 +256,7 @@ func (mm *MigrateExecutor) modifyTable(tableFromCode Table, columnsFromCode []Co fmt.Println(columnCode.ColumnName.String, columnCode.DataType.String, columnDb.DataType.String) sql := "ALTER TABLE " + tableFromCode.TableName.String + " alter COLUMN " + getColumnStr(columnCode, "driver") - //fmt.Println(model) + //fmt.Println(base) _, err := mm.Builder.RawSql(sql).Exec() if err != nil { diff --git a/migrator/migrator.go b/migrator/migrator.go index 62b4bdb..4f92cb7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -1,13 +1,13 @@ package migrator import ( + "github.com/tangpanqing/aorm/base" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/driver" "github.com/tangpanqing/aorm/migrate_mssql" "github.com/tangpanqing/aorm/migrate_mysql" "github.com/tangpanqing/aorm/migrate_postgres" "github.com/tangpanqing/aorm/migrate_sqlite3" - "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/utils" "reflect" "strings" @@ -15,7 +15,7 @@ import ( type Migrator struct { //数据库操作连接 - Link model.AormLink + Link base.Link } //ShowCreateTable 获取创建表的ddl diff --git a/model/AormDB.go b/model/AormDB.go deleted file mode 100644 index dbd6b4b..0000000 --- a/model/AormDB.go +++ /dev/null @@ -1,80 +0,0 @@ -package model - -import ( - "database/sql" - "time" -) - -// AormDB 数据库连接与数据库类型 -type AormDB struct { - Driver string - DebugMode bool - SqlDB *sql.DB -} - -//Close 关闭 -func (db *AormDB) Close() error { - return db.SqlDB.Close() -} - -//Begin 开始一个事务 -func (db *AormDB) Begin() *AormTx { - SqlTx, _ := db.SqlDB.Begin() - - return &AormTx{ - driver: db.Driver, - debugMode: db.DebugMode, - - sqlTx: SqlTx, - } -} - -//SetDebugMode 获取调试模式 -func (db *AormDB) SetDebugMode(debugMode bool) { - db.DebugMode = debugMode -} - -func (db *AormDB) SetConnMaxLifetime(d time.Duration) { - db.SqlDB.SetConnMaxLifetime(d) -} - -func (db *AormDB) SetConnMaxIdleTime(d time.Duration) { - db.SqlDB.SetConnMaxIdleTime(d) -} - -func (db *AormDB) SetMaxIdleConns(n int) { - db.SqlDB.SetMaxIdleConns(n) -} - -func (db *AormDB) SetMaxOpenConns(n int) { - db.SqlDB.SetMaxOpenConns(n) -} - -func (db *AormDB) Stats() sql.DBStats { - return db.SqlDB.Stats() -} - -//GetDebugMode 获取调试模式 -func (db *AormDB) GetDebugMode() bool { - return db.DebugMode -} - -func (db *AormDB) DriverName() string { - return db.Driver -} - -func (db *AormDB) Exec(query string, args ...interface{}) (sql.Result, error) { - return db.SqlDB.Exec(query, args...) -} - -func (db *AormDB) Prepare(query string) (*sql.Stmt, error) { - return db.SqlDB.Prepare(query) -} - -func (db *AormDB) Query(query string, args ...interface{}) (*sql.Rows, error) { - return db.SqlDB.Query(query, args...) -} - -func (db *AormDB) QueryRow(query string, args ...interface{}) *sql.Row { - return db.SqlDB.QueryRow(query, args...) -} diff --git a/model/AormTx.go b/model/AormTx.go deleted file mode 100644 index b9dd9b6..0000000 --- a/model/AormTx.go +++ /dev/null @@ -1,42 +0,0 @@ -package model - -import "database/sql" - -type AormTx struct { - driver string - debugMode bool - sqlTx *sql.Tx -} - -//GetDebugMode 获取调试状态 -func (tx *AormTx) GetDebugMode() bool { - return tx.debugMode -} - -func (tx *AormTx) DriverName() string { - return tx.driver -} - -func (tx *AormTx) Exec(query string, args ...interface{}) (sql.Result, error) { - return tx.sqlTx.Exec(query, args...) -} - -func (tx *AormTx) Prepare(query string) (*sql.Stmt, error) { - return tx.sqlTx.Prepare(query) -} - -func (tx *AormTx) Query(query string, args ...interface{}) (*sql.Rows, error) { - return tx.sqlTx.Query(query, args...) -} - -func (tx *AormTx) QueryRow(query string, args ...interface{}) *sql.Row { - return tx.sqlTx.QueryRow(query, args...) -} - -func (tx *AormTx) Rollback() error { - return tx.sqlTx.Rollback() -} - -func (tx *AormTx) Commit() error { - return tx.sqlTx.Commit() -} diff --git a/test/aorm_test.go b/test/aorm_test.go index 8f30e0e..aad0e11 100644 --- a/test/aorm_test.go +++ b/test/aorm_test.go @@ -7,9 +7,9 @@ import ( _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/tangpanqing/aorm" + "github.com/tangpanqing/aorm/base" "github.com/tangpanqing/aorm/builder" "github.com/tangpanqing/aorm/driver" - "github.com/tangpanqing/aorm/model" "github.com/tangpanqing/aorm/null" "testing" "time" @@ -96,7 +96,7 @@ func TestAll(t *testing.T) { aorm.Store(&articleVO) aorm.Store(&personAge, &personWithArticleCount) - var dbList = []*model.AormDB{ + var dbList = []*base.Db{ testMysqlConnect(), testSqlite3Connect(), testPostgresConnect(), @@ -166,7 +166,7 @@ func TestAll(t *testing.T) { testDbContent() } -func testSqlite3Connect() *model.AormDB { +func testSqlite3Connect() *base.Db { sqlite3Content, sqlite3Err := aorm.Open(driver.Sqlite3, "test.db") if sqlite3Err != nil { panic(sqlite3Err) @@ -176,7 +176,7 @@ func testSqlite3Connect() *model.AormDB { return sqlite3Content } -func testMysqlConnect() *model.AormDB { +func testMysqlConnect() *base.Db { username := "root" password := "root" hostname := "localhost" @@ -192,7 +192,7 @@ func testMysqlConnect() *model.AormDB { return mysqlContent } -func testPostgresConnect() *model.AormDB { +func testPostgresConnect() *base.Db { psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", "localhost", 5432, "postgres", "root", "postgres") postgresContent, postgresErr := aorm.Open(driver.Postgres, psqlInfo) @@ -205,7 +205,7 @@ func testPostgresConnect() *model.AormDB { return postgresContent } -func testMssqlConnect() *model.AormDB { +func testMssqlConnect() *base.Db { mssqlInfo := fmt.Sprintf("server=%s;database=%s;user id=%s;password=%s;port=%d;encrypt=disable", "localhost", "database_name", "sa", "root", 1433) mssqlContent, mssqlErr := aorm.Open(driver.Mssql, mssqlInfo) @@ -217,17 +217,17 @@ func testMssqlConnect() *model.AormDB { return mssqlContent } -func testMigrate(db *model.AormDB) { +func testMigrate(db *base.Db) { aorm.Migrator(db).AutoMigrate(&person, &article, &student) aorm.Migrator(db).Migrate("person_1", &person) } -func testShowCreateTable(db *model.AormDB) { +func testShowCreateTable(db *base.Db) { aorm.Migrator(db).ShowCreateTable("person") } -func testInsert(db *model.AormDB) int64 { +func testInsert(db *base.Db) int64 { obj := Person{ Name: null.StringFrom("Alice"), Sex: null.BoolFrom(true), @@ -286,7 +286,7 @@ func testInsert(db *model.AormDB) int64 { return id } -func testInsertBatch(db *model.AormDB) int64 { +func testInsertBatch(db *base.Db) int64 { var batch []*Person batch = append(batch, &Person{ Name: null.StringFrom("Alice"), @@ -316,7 +316,7 @@ func testInsertBatch(db *model.AormDB) int64 { return count } -func testGetOne(db *model.AormDB, id int64) { +func testGetOne(db *base.Db, id int64) { var personItem Person errFind := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).GetOne(&personItem) if errFind != nil { @@ -324,7 +324,7 @@ func testGetOne(db *model.AormDB, id int64) { } } -func testGetMany(db *model.AormDB) { +func testGetMany(db *base.Db) { var list []Person errSelect := aorm.Db(db).Table(&person).WhereEq(&person.Type, 0).GetMany(&list) if errSelect != nil { @@ -332,14 +332,14 @@ func testGetMany(db *model.AormDB) { } } -func testUpdate(db *model.AormDB, id int64) { +func testUpdate(db *base.Db, id int64) { _, errUpdate := aorm.Db(db).WhereEq(&person.Id, id).Update(&Person{Name: null.StringFrom("Bob")}) if errUpdate != nil { panic(db.DriverName() + "testUpdate" + "found err") } } -func testDelete(db *model.AormDB, id int64) { +func testDelete(db *base.Db, id int64) { _, errDelete := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Delete() if errDelete != nil { panic(db.DriverName() + "testDelete" + "found err") @@ -353,7 +353,7 @@ func testDelete(db *model.AormDB, id int64) { } } -func testExists(db *model.AormDB, id int64) bool { +func testExists(db *base.Db, id int64) bool { exists, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).OrderBy(&person.Id, builder.Desc).Exists() if err != nil { panic(db.DriverName() + " testExists " + "found err:" + err.Error()) @@ -361,7 +361,7 @@ func testExists(db *model.AormDB, id int64) bool { return exists } -func testTable(db *model.AormDB) { +func testTable(db *base.Db) { _, err := aorm.Db(db).Table("person_1").Insert(&Person{Name: null.StringFrom("Cherry")}) if err != nil { panic(db.DriverName() + " testTable " + "found err:" + err.Error()) @@ -373,7 +373,7 @@ func testTable(db *model.AormDB) { } } -func testSelect(db *model.AormDB) { +func testSelect(db *base.Db) { var listByFiled []Person err := aorm.Db(db).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) if err != nil { @@ -381,7 +381,7 @@ func testSelect(db *model.AormDB) { } } -func testSelectWithSub(db *model.AormDB) { +func testSelectWithSub(db *base.Db) { var listByFiled []PersonWithArticleCount sub := aorm.Db(db).Table(&article).SelectCount(&article.Id, "article_count_tem").WhereRawEq(&article.PersonId, &person.Id) @@ -397,7 +397,7 @@ func testSelectWithSub(db *model.AormDB) { } } -func testWhereWithSub(db *model.AormDB) { +func testWhereWithSub(db *base.Db) { var listByFiled []Person sub := aorm.Db(db).Table(&article).SelectCount(&article.PersonId, "count_person_id").GroupBy(&article.PersonId).HavingGt("count_person_id", 0) err := aorm.Db(db). @@ -409,7 +409,7 @@ func testWhereWithSub(db *model.AormDB) { } } -func testWhere(db *model.AormDB) { +func testWhere(db *base.Db) { var listByWhere []Person err := aorm.Db(db).Table(&person).WhereArr([]builder.WhereItem{ builder.GenWhereItem(&person.Type, builder.Eq, 0), @@ -423,7 +423,7 @@ func testWhere(db *model.AormDB) { } } -func testJoin(db *model.AormDB) { +func testJoin(db *base.Db) { var list2 []ArticleVO err := aorm.Db(db). Table(&article). @@ -443,7 +443,7 @@ func testJoin(db *model.AormDB) { } } -func testJoinWithAlias(db *model.AormDB) { +func testJoinWithAlias(db *base.Db) { var list2 []ArticleVO err := aorm.Db(db). Table(&article, "o"). @@ -464,7 +464,7 @@ func testJoinWithAlias(db *model.AormDB) { } } -func testGroupBy(db *model.AormDB) { +func testGroupBy(db *base.Db) { var personAgeItem PersonAge err := aorm.Db(db). Table(&person). @@ -479,7 +479,7 @@ func testGroupBy(db *model.AormDB) { } } -func testHaving(db *model.AormDB) { +func testHaving(db *base.Db) { var listByHaving []PersonAge err := aorm.Db(db). @@ -496,7 +496,7 @@ func testHaving(db *model.AormDB) { } } -func testOrderBy(db *model.AormDB) { +func testOrderBy(db *base.Db) { var listByOrder []Person err := aorm.Db(db). Table(&person). @@ -518,7 +518,7 @@ func testOrderBy(db *model.AormDB) { } } -func testLimit(db *model.AormDB) { +func testLimit(db *base.Db) { var list3 []Person err1 := aorm.Db(db). Table(&person). @@ -542,7 +542,7 @@ func testLimit(db *model.AormDB) { } } -func testLock(db *model.AormDB, id int64) { +func testLock(db *base.Db, id int64) { if db.DriverName() == driver.Sqlite3 || db.DriverName() == driver.Mssql { return } @@ -559,21 +559,21 @@ func testLock(db *model.AormDB, id int64) { } } -func testIncrement(db *model.AormDB, id int64) { +func testIncrement(db *base.Db, id int64) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Increment(&person.Age, 1) if err != nil { panic(db.DriverName() + " testIncrement " + "found err:" + err.Error()) } } -func testDecrement(db *model.AormDB, id int64) { +func testDecrement(db *base.Db, id int64) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Id, id).Decrement(&person.Age, 2) if err != nil { panic(db.DriverName() + "testDecrement" + "found err") } } -func testValue(db *model.AormDB, id int64) { +func testValue(db *base.Db, id int64) { var name string errName := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Id, id).Value(&person.Name, &name) @@ -600,7 +600,7 @@ func testValue(db *model.AormDB, id int64) { } } -func testPluck(db *model.AormDB) { +func testPluck(db *base.Db) { var nameList []string errNameList := aorm.Db(db).Table(&person).OrderBy(&person.Id, builder.Desc).WhereEq(&person.Type, 0).Limit(0, 3).Pluck(&person.Name, &nameList) if errNameList != nil { @@ -626,42 +626,42 @@ func testPluck(db *model.AormDB) { } } -func testCount(db *model.AormDB) { +func testCount(db *base.Db) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Count("*") if err != nil { panic(db.DriverName() + "testCount" + "found err") } } -func testSum(db *model.AormDB) { +func testSum(db *base.Db) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Sum(&person.Age) if err != nil { panic(db.DriverName() + "testSum" + "found err") } } -func testAvg(db *model.AormDB) { +func testAvg(db *base.Db) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Avg(&person.Age) if err != nil { panic(db.DriverName() + "testAvg" + "found err") } } -func testMin(db *model.AormDB) { +func testMin(db *base.Db) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Min(&person.Age) if err != nil { panic(db.DriverName() + "testMin" + "found err") } } -func testMax(db *model.AormDB) { +func testMax(db *base.Db) { _, err := aorm.Db(db).Table(&person).WhereEq(&person.Age, 18).Max(&person.Age) if err != nil { panic(db.DriverName() + "testMax" + "found err") } } -func testDistinct(db *model.AormDB) { +func testDistinct(db *base.Db) { var listByFiled []Person err := aorm.Db(db).Distinct(true).Table(&person).Select(&person.Name).Select(&person.Age).WhereEq(&person.Age, 18).GetMany(&listByFiled) if err != nil { @@ -669,7 +669,7 @@ func testDistinct(db *model.AormDB) { } } -func testRawSql(db *model.AormDB, id2 int64) { +func testRawSql(db *base.Db, id2 int64) { var list []Person err1 := aorm.Db(db).RawSql("SELECT * FROM person WHERE id=?", id2).GetMany(&list) if err1 != nil { @@ -682,7 +682,7 @@ func testRawSql(db *model.AormDB, id2 int64) { } } -func testTransaction(db *model.AormDB) { +func testTransaction(db *base.Db) { tx := db.Begin() id, errInsert := aorm.Db(tx).Insert(&Person{ @@ -725,7 +725,7 @@ func testTransaction(db *model.AormDB) { tx.Commit() } -func testTruncate(db *model.AormDB) { +func testTruncate(db *base.Db) { _, err := aorm.Db(db).Table(&person).Truncate() if err != nil { panic(db.DriverName() + " testTruncate " + "found err") @@ -809,7 +809,7 @@ func testDbContent() { tx.Commit() } -func closeAll(dbList []*model.AormDB) { +func closeAll(dbList []*base.Db) { for i := 0; i < len(dbList); i++ { dbList[i].Close() }