This commit is contained in:
tangpanqing
2023-01-06 17:11:55 +08:00
parent 9f9a7c975f
commit e65a46e49c
3 changed files with 158 additions and 122 deletions

View File

@@ -94,14 +94,27 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
typeOf := reflect.TypeOf(dest)
valueOf := reflect.ValueOf(dest)
//主键名字
var primaryKey = ""
var keys []string
var paramList []any
var place []string
for i := 0; i < typeOf.Elem().NumField(); i++ {
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
//如果是Postgres数据库寻找主键
if b.driverName == model.Postgres {
tag := typeOf.Elem().Field(i).Tag.Get("aorm")
if -1 != strings.Index(tag, "primary") {
primaryKey = key
}
}
isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool()
if isNotNull {
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
val := valueOf.Elem().Field(i).Field(0).Field(0).Interface()
keys = append(keys, key)
paramList = append(paramList, val)
place = append(place, "?")
@@ -110,14 +123,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")"
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
if b.driverName == model.Mssql {
return b.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...)
return b.insertForMssqlOrPostgres(sqlStr+"; SELECT SCOPE_IDENTITY()", paramList...)
} else if b.driverName == model.Postgres {
return b.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...)
sqlStr = convertToPostgresSql(sqlStr)
return b.insertForMssqlOrPostgres(sqlStr+" RETURNING "+primaryKey, paramList...)
} else {
return b.insertForCommon(sqlStr, paramList...)
}
@@ -125,6 +135,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
//对于Mssql,Postgres类型数据库为了获取最后插入的id需要改写入为查询
func (b *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) {
if b.isDebug {
fmt.Println(sql)
fmt.Println(paramList...)
}
rows, err := b.LinkCommon.Query(sql, paramList...)
if err != nil {
return 0, err
@@ -324,21 +339,21 @@ func (b *Builder) GetSqlAndParams() (string, []interface{}) {
groupStr, paramList := b.handleGroup(paramList)
havingStr, paramList := b.handleHaving(paramList)
orderStr, paramList := b.handleOrder(paramList)
limitStr, paramList := b.handleLimit(b.offset, b.pageSize, paramList)
lockStr := handleLockForUpdate(b.isLockForUpdate)
limitStr, paramList := b.handleLimit(paramList)
lockStr := b.handleLockForUpdate()
sqlStr := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
sql := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
sql = convertToPostgresSql(sql)
}
if b.isDebug {
fmt.Println(sqlStr)
//fmt.Println(paramList...)
fmt.Println(sql)
fmt.Println(paramList...)
}
return sqlStr, paramList
return sql, paramList
}
// Update 更新记录
@@ -351,10 +366,6 @@ func (b *Builder) Update(dest interface{}) (int64, error) {
whereStr, paramList := b.handleWhere(paramList)
sqlStr := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return b.ExecAffected(sqlStr, paramList...)
}
@@ -378,18 +389,16 @@ func (b *Builder) Delete(destList ...interface{}) (int64, error) {
whereStr, paramList := b.handleWhere(paramList)
sqlStr := "DELETE FROM " + tableName + whereStr
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return b.ExecAffected(sqlStr, paramList...)
}
// Truncate 清空记录, sqlite3不支持此操作
// Truncate 清空记录
func (b *Builder) Truncate() (int64, error) {
sqlStr := "TRUNCATE TABLE " + getTableNameByTable(b.table)
sqlStr := ""
if b.driverName == model.Sqlite3 {
sqlStr = "DELETE FROM " + getTableNameByTable(b.table)
} else {
sqlStr = "TRUNCATE TABLE " + getTableNameByTable(b.table)
}
return b.ExecAffected(sqlStr)
@@ -399,7 +408,7 @@ func (b *Builder) Truncate() (int64, error) {
func (b *Builder) Exists() (bool, error) {
var obj IntStruct
err := b.selectCommon("", "1 as c", nil, "").Limit(0, 1).GetOne(&obj)
err := b.selectCommon("", "1 AS c", nil, "").Limit(0, 1).GetOne(&obj)
if err != nil {
return false, err
}
@@ -458,8 +467,9 @@ func (b *Builder) Value(field interface{}, dest interface{}) error {
}
// Pluck 获取某一列的值
func (b *Builder) Pluck(fieldName interface{}, values interface{}) error {
b.Select(fieldName)
func (b *Builder) Pluck(field interface{}, values interface{}) error {
b.Select(field)
fieldName := getFieldName(field)
rows, errRows := b.GetRows()
defer rows.Close()
@@ -506,10 +516,6 @@ func (b *Builder) Increment(field interface{}, step int) (int64, error) {
whereStr, paramList := b.handleWhere(paramList)
sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "+?" + whereStr
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return b.ExecAffected(sqlStr, paramList...)
}
@@ -520,31 +526,27 @@ func (b *Builder) Decrement(field interface{}, step int) (int64, error) {
whereStr, paramList := b.handleWhere(paramList)
sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "-?" + whereStr
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
}
return b.ExecAffected(sqlStr, paramList...)
}
// Exec 通用执行-新增,更新,删除
func (b *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
func (b *Builder) Exec(sql string, paramList ...interface{}) (sql.Result, error) {
if b.driverName == model.Postgres {
sqlStr = convertToPostgresSql(sqlStr)
sql = convertToPostgresSql(sql)
}
if b.isDebug {
fmt.Println(sqlStr)
//fmt.Println(args...)
fmt.Println(sql)
fmt.Println(paramList...)
}
smt, err1 := b.LinkCommon.Prepare(sqlStr)
smt, err1 := b.LinkCommon.Prepare(sql)
if err1 != nil {
return nil, err1
}
defer smt.Close()
res, err2 := smt.Exec(args...)
res, err2 := smt.Exec(paramList...)
if err2 != nil {
return nil, err2
}
@@ -554,8 +556,12 @@ func (b *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
}
// ExecAffected 通用执行-更新,删除
func (b *Builder) ExecAffected(sqlStr string, args ...interface{}) (int64, error) {
res, err := b.Exec(sqlStr, args...)
func (b *Builder) ExecAffected(sql string, paramList ...interface{}) (int64, error) {
if b.driverName == model.Postgres {
sql = convertToPostgresSql(sql)
}
res, err := b.Exec(sql, paramList...)
if err != nil {
return 0, err
}
@@ -624,14 +630,25 @@ func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder {
}
//拼接SQL,查询与筛选通用操作
func (b *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) {
func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHaving bool) ([]string, []any) {
var whereList []string
for i := 0; i < len(where); i++ {
allFieldName := ""
if where[i].Prefix != "" {
allFieldName += where[i].Prefix + "."
}
allFieldName += getFieldName(where[i].Field)
//如果是mssql或者Postgres,并且来自having的话需要特殊处理
if (b.driverName == model.Mssql || b.driverName == model.Postgres) && isFromHaving {
fieldNameCurrent := getFieldName(where[i].Field)
for m := 0; m < len(b.selectList); m++ {
if fieldNameCurrent == getFieldName(b.selectList[m].FieldNew) {
allFieldName += handleFieldWith(b.selectList[m])
}
}
} else {
allFieldName += getFieldName(where[i].Field)
}
if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() {
executor := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer()))