support sqlite3

This commit is contained in:
tangpanqing
2022-12-26 11:15:21 +08:00
parent e9218503c5
commit bb9d0dfa69
5 changed files with 216 additions and 182 deletions

View File

@@ -101,50 +101,50 @@ func (ex *Builder) Insert(dest interface{}) (int64, error) {
sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")"
//如果是postgres,则转换?号到&1等
if ex.driverName == "postgres" {
sqlStr = coverSql(sqlStr)
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
//如果是mssql
if ex.driverName == "mssql" {
rows, err := ex.LinkCommon.Query(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...)
if err != nil {
return 0, err
}
defer rows.Close()
var lastInsertId1 int64
for rows.Next() {
rows.Scan(&lastInsertId1)
}
return lastInsertId1, nil
} else if ex.driverName == "postgres" {
rows, err := ex.LinkCommon.Query(sqlStr+" returning id", paramList...)
if err != nil {
return 0, err
}
defer rows.Close()
var lastInsertId1 int64
for rows.Next() {
rows.Scan(&lastInsertId1)
}
return lastInsertId1, nil
if ex.driverName == model.Mssql {
return ex.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...)
} else if ex.driverName == model.Postgres {
return ex.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...)
} else {
res, err := ex.Exec(sqlStr, paramList...)
if err != nil {
return 0, err
}
lastId, err := res.LastInsertId()
if err != nil {
return 0, err
}
return lastId, nil
return ex.insertForCommon(sqlStr, paramList...)
}
}
func coverSql(sqlStr string) string {
//对于Mssql,Postgres类型数据库为了获取最后插入的id需要改写入为查询
func (ex *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) {
rows, err := ex.LinkCommon.Query(sql, paramList...)
if err != nil {
return 0, err
}
defer rows.Close()
var lastInsertId1 int64
for rows.Next() {
rows.Scan(&lastInsertId1)
}
return lastInsertId1, nil
}
//对于非Mssql,Postgres类型数据库可以直接获取最后插入的id
func (ex *Builder) insertForCommon(sql string, paramList ...any) (int64, error) {
res, err := ex.Exec(sql, paramList...)
if err != nil {
return 0, err
}
lastId, err := res.LastInsertId()
if err != nil {
return 0, err
}
return lastId, nil
}
//对于Postgres数据库不支持?占位符,支持$1,$2类型需要做转换
func convertToPostgresSql(sqlStr string) string {
t := 1
for {
if strings.Index(sqlStr, "?") == -1 {
@@ -197,8 +197,8 @@ func (ex *Builder) InsertBatch(values interface{}) (int64, error) {
sqlStr := "INSERT INTO " + ex.tableName + " (" + strings.Join(keys, ",") + ") VALUES " + strings.Join(place, ",")
if ex.driverName == "postgres" {
sqlStr = coverSql(sqlStr)
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
res, err := ex.Exec(sqlStr, paramList...)
@@ -325,9 +325,8 @@ func (ex *Builder) GetSqlAndParams() (string, []interface{}) {
sqlStr := "SELECT " + fieldStr + " FROM " + ex.tableName + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
//如果是postgres,则转换?号到&1等
if ex.driverName == "postgres" {
sqlStr = coverSql(sqlStr)
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
if ex.isDebug {
@@ -345,9 +344,8 @@ func (ex *Builder) Update(dest interface{}) (int64, error) {
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
sqlStr := "UPDATE " + ex.tableName + setStr + whereStr
//如果是postgres,则转换?号到&1等
if ex.driverName == "postgres" {
sqlStr = coverSql(sqlStr)
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return ex.ExecAffected(sqlStr, paramList...)
@@ -359,9 +357,8 @@ func (ex *Builder) Delete() (int64, error) {
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
sqlStr := "DELETE FROM " + ex.tableName + whereStr
//如果是postgres,则转换?号到&1等
if ex.driverName == "postgres" {
sqlStr = coverSql(sqlStr)
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return ex.ExecAffected(sqlStr, paramList...)
@@ -370,7 +367,7 @@ func (ex *Builder) Delete() (int64, error) {
// Truncate 清空记录, sqlte3不支持此操作
func (ex *Builder) Truncate() (int64, error) {
sqlStr := "TRUNCATE TABLE " + ex.tableName
if ex.driverName == "sqlite3" {
if ex.driverName == model.Sqlite3 {
sqlStr = "DELETE FROM " + ex.tableName
}
@@ -464,6 +461,10 @@ func (ex *Builder) Increment(fieldName string, step int) (int64, error) {
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "+?" + whereStr
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return ex.ExecAffected(sqlStr, paramList...)
}
@@ -474,11 +475,19 @@ func (ex *Builder) Decrement(fieldName string, step int) (int64, error) {
whereStr, paramList := ex.handleWhere(ex.whereList, paramList)
sqlStr := "UPDATE " + ex.tableName + " SET " + fieldName + "=" + fieldName + "-?" + whereStr
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return ex.ExecAffected(sqlStr, paramList...)
}
// Exec 通用执行-新增,更新,删除
func (ex *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
if ex.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
if ex.isDebug {
fmt.Println(sqlStr)
fmt.Println(args...)
@@ -574,14 +583,14 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
}
} else {
if where[i].Opt == Eq || where[i].Opt == Ne || where[i].Opt == Gt || where[i].Opt == Ge || where[i].Opt == Lt || where[i].Opt == Le {
if ex.driverName == "sqlite3" {
if ex.driverName == model.Sqlite3 {
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?")
} else {
switch where[i].Val.(type) {
case float32:
whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
case float64:
whereList = append(whereList, ex.getConcat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
whereList = append(whereList, ex.getConcatForFloat(where[i].Field, "''")+" "+where[i].Opt+" "+"?")
default:
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+"?")
}
@@ -610,7 +619,7 @@ func (ex *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string,
}
}
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcat(valueStr...))
whereList = append(whereList, where[i].Field+" "+where[i].Opt+" "+ex.getConcatForLike(valueStr...))
}
if where[i].Opt == In || where[i].Opt == NotIn {
@@ -701,8 +710,18 @@ func getScans(columnNameList []string, fieldNameMap map[string]int, destValue re
return scans
}
func (ex *Builder) getConcat(vars ...string) string {
if ex.driverName == "sqlite3" {
func (ex *Builder) getConcatForFloat(vars ...string) string {
if ex.driverName == model.Sqlite3 {
return strings.Join(vars, "||")
} else if ex.driverName == model.Postgres {
return vars[0]
} else {
return "CONCAT(" + strings.Join(vars, ",") + ")"
}
}
func (ex *Builder) getConcatForLike(vars ...string) string {
if ex.driverName == model.Sqlite3 || ex.driverName == model.Postgres {
return strings.Join(vars, "||")
} else {
return "CONCAT(" + strings.Join(vars, ",") + ")"