mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-10-30 10:36:22 +08:00
support sqlite3
This commit is contained in:
131
builder/crud.go
131
builder/crud.go
@@ -101,50 +101,50 @@ func (ex *Builder) Insert(dest interface{}) (int64, error) {
|
||||
|
||||
sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")"
|
||||
|
||||
//如果是postgres,则转换?号到&1等
|
||||
if ex.driverName == "postgres" {
|
||||
sqlStr = coverSql(sqlStr)
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
//如果是mssql
|
||||
if ex.driverName == "mssql" {
|
||||
rows, err := ex.LinkCommon.Query(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var lastInsertId1 int64
|
||||
for rows.Next() {
|
||||
rows.Scan(&lastInsertId1)
|
||||
}
|
||||
return lastInsertId1, nil
|
||||
} else if ex.driverName == "postgres" {
|
||||
rows, err := ex.LinkCommon.Query(sqlStr+" returning id", paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var lastInsertId1 int64
|
||||
for rows.Next() {
|
||||
rows.Scan(&lastInsertId1)
|
||||
}
|
||||
return lastInsertId1, nil
|
||||
if ex.driverName == model.Mssql {
|
||||
return ex.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...)
|
||||
} else if ex.driverName == model.Postgres {
|
||||
return ex.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...)
|
||||
} else {
|
||||
res, err := ex.Exec(sqlStr, paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
lastId, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return lastId, nil
|
||||
return ex.insertForCommon(sqlStr, paramList...)
|
||||
}
|
||||
}
|
||||
|
||||
func coverSql(sqlStr string) string {
|
||||
//对于Mssql,Postgres类型数据库,为了获取最后插入的id,需要改写入为查询
|
||||
func (ex *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) {
|
||||
rows, err := ex.LinkCommon.Query(sql, paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var lastInsertId1 int64
|
||||
for rows.Next() {
|
||||
rows.Scan(&lastInsertId1)
|
||||
}
|
||||
return lastInsertId1, nil
|
||||
}
|
||||
|
||||
//对于非Mssql,Postgres类型数据库,可以直接获取最后插入的id
|
||||
func (ex *Builder) insertForCommon(sql string, paramList ...any) (int64, error) {
|
||||
res, err := ex.Exec(sql, paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
lastId, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return lastId, nil
|
||||
}
|
||||
|
||||
//对于Postgres数据库,不支持?占位符,支持$1,$2类型,需要做转换
|
||||
func convertToPostgresSql(sqlStr string) string {
|
||||
t := 1
|
||||
for {
|
||||
if strings.Index(sqlStr, "?") == -1 {
|
||||
@@ -197,8 +197,8 @@ func (ex *Builder) InsertBatch(values interface{}) (int64, error) {
|
||||
|
||||
sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",")
|
||||
|
||||
if ex.driverName == "postgres" {
|
||||
sqlStr = coverSql(sqlStr)
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
res, err := ex.Exec(sqlStr, paramList...)
|
||||
@@ -325,9 +325,8 @@ func (ex *Builder) GetSqlAndParams() (string, []interface{}) {
|
||||
|
||||
sqlStr := "SELECT " + fieldStr + " FROM " + ex.tableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
|
||||
|
||||
//如果是postgres,则转换?号到&1等
|
||||
if ex.driverName == "postgres" {
|
||||
sqlStr = coverSql(sqlStr)
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
if ex.isDebug {
|
||||
@@ -345,9 +344,8 @@ func (ex *Builder) Update(dest interface{}) (int64, error) {
|
||||
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
|
||||
sqlStr := "UPDATE " + ex.tableName + setStr + whereStr
|
||||
|
||||
//如果是postgres,则转换?号到&1等
|
||||
if ex.driverName == "postgres" {
|
||||
sqlStr = coverSql(sqlStr)
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return ex.ExecAffected(sqlStr, paramList...)
|
||||
@@ -359,9 +357,8 @@ func (ex *Builder) Delete() (int64, error) {
|
||||
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
|
||||
sqlStr := "DELETE FROM " + ex.tableName + whereStr
|
||||
|
||||
//如果是postgres,则转换?号到&1等
|
||||
if ex.driverName == "postgres" {
|
||||
sqlStr = coverSql(sqlStr)
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return ex.ExecAffected(sqlStr, paramList...)
|
||||
@@ -370,7 +367,7 @@ func (ex *Builder) Delete() (int64, error) {
|
||||
// Truncate 清空记录, sqlte3不支持此操作
|
||||
func (ex *Builder) Truncate() (int64, error) {
|
||||
sqlStr := "TRUNCATE TABLE " + ex.tableName
|
||||
if ex.driverName == "sqlite3" {
|
||||
if ex.driverName == model.Sqlite3 {
|
||||
sqlStr = "DELETE FROM " + ex.tableName
|
||||
}
|
||||
|
||||
@@ -464,6 +461,10 @@ func (ex *Builder) Increment(fieldName string, step int) (int64, error) {
|
||||
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
|
||||
sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "+?" + whereStr
|
||||
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return ex.ExecAffected(sqlStr, paramList...)
|
||||
}
|
||||
|
||||
@@ -474,11 +475,19 @@ func (ex *Builder) Decrement(fieldName string, step int) (int64, error) {
|
||||
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
|
||||
sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "-?" + whereStr
|
||||
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return ex.ExecAffected(sqlStr, paramList...)
|
||||
}
|
||||
|
||||
// Exec 通用执行-新增,更新,删除
|
||||
func (ex *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
|
||||
if ex.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
if ex.isDebug {
|
||||
fmt.Println(sqlStr)
|
||||
fmt.Println(args...)
|
||||
@@ -574,14 +583,14 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
|
||||
}
|
||||
} else {
|
||||
if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le {
|
||||
if ex.driverName == "sqlite3" {
|
||||
if ex.driverName == model.Sqlite3 {
|
||||
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?")
|
||||
} else {
|
||||
switch where[i].Val.(type) {
|
||||
case float32:
|
||||
whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
|
||||
whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
|
||||
case float64:
|
||||
whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
|
||||
whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
|
||||
default:
|
||||
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?")
|
||||
}
|
||||
@@ -610,7 +619,7 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
|
||||
}
|
||||
}
|
||||
|
||||
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcat(valueStr...))
|
||||
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcatForLike(valueStr...))
|
||||
}
|
||||
|
||||
if where[i].Opt == In || where[i].Opt == NotIn {
|
||||
@@ -701,8 +710,18 @@ func getScans(columnNameList []string, fieldNameMap map[string]int, destValue re
|
||||
return scans
|
||||
}
|
||||
|
||||
func (ex *Builder) getConcat(vars ...string) string {
|
||||
if ex.driverName == "sqlite3" {
|
||||
func (ex *Builder) getConcatForFloat(vars ...string) string {
|
||||
if ex.driverName == model.Sqlite3 {
|
||||
return strings.Join(vars, "||")
|
||||
} else if ex.driverName == model.Postgres {
|
||||
return vars[0]
|
||||
} else {
|
||||
return "CONCAT(" + strings.Join(vars, ",") + ")"
|
||||
}
|
||||
}
|
||||
|
||||
func (ex *Builder) getConcatForLike(vars ...string) string {
|
||||
if ex.driverName == model.Sqlite3 || ex.driverName == model.Postgres {
|
||||
return strings.Join(vars, "||")
|
||||
} else {
|
||||
return "CONCAT(" + strings.Join(vars, ",") + ")"
|
||||
|
||||
Reference in New Issue
Block a user