mirror of
https://github.com/tangpanqing/aorm.git
synced 2025-10-06 08:26:55 +08:00
new test
This commit is contained in:
107
builder/crud.go
107
builder/crud.go
@@ -94,14 +94,27 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
|
||||
typeOf := reflect.TypeOf(dest)
|
||||
valueOf := reflect.ValueOf(dest)
|
||||
|
||||
//主键名字
|
||||
var primaryKey = ""
|
||||
|
||||
var keys []string
|
||||
var paramList []any
|
||||
var place []string
|
||||
for i := 0; i < typeOf.Elem().NumField(); i++ {
|
||||
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
|
||||
|
||||
//如果是Postgres数据库,寻找主键
|
||||
if b.driverName == model.Postgres {
|
||||
tag := typeOf.Elem().Field(i).Tag.Get("aorm")
|
||||
if -1 != strings.Index(tag, "primary") {
|
||||
primaryKey = key
|
||||
}
|
||||
}
|
||||
|
||||
isNotNull := valueOf.Elem().Field(i).Field(0).Field(1).Bool()
|
||||
if isNotNull {
|
||||
key := helper.UnderLine(typeOf.Elem().Field(i).Name)
|
||||
val := valueOf.Elem().Field(i).Field(0).Field(0).Interface()
|
||||
|
||||
keys = append(keys, key)
|
||||
paramList = append(paramList, val)
|
||||
place = append(place, "?")
|
||||
@@ -110,14 +123,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
|
||||
|
||||
sqlStr := "INSERT INTO " + b.getTableNameCommon(typeOf, valueOf) + " (" + strings.Join(keys, ",") + ") VALUES (" + strings.Join(place, ",") + ")"
|
||||
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
if b.driverName == model.Mssql {
|
||||
return b.insertForMssqlOrPostgres(sqlStr+"; select ID = convert(bigint, SCOPE_IDENTITY())", paramList...)
|
||||
return b.insertForMssqlOrPostgres(sqlStr+"; SELECT SCOPE_IDENTITY()", paramList...)
|
||||
} else if b.driverName == model.Postgres {
|
||||
return b.insertForMssqlOrPostgres(sqlStr+" returning id", paramList...)
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
return b.insertForMssqlOrPostgres(sqlStr+" RETURNING "+primaryKey, paramList...)
|
||||
} else {
|
||||
return b.insertForCommon(sqlStr, paramList...)
|
||||
}
|
||||
@@ -125,6 +135,11 @@ func (b *Builder) Insert(dest interface{}) (int64, error) {
|
||||
|
||||
//对于Mssql,Postgres类型数据库,为了获取最后插入的id,需要改写入为查询
|
||||
func (b *Builder) insertForMssqlOrPostgres(sql string, paramList ...any) (int64, error) {
|
||||
if b.isDebug {
|
||||
fmt.Println(sql)
|
||||
fmt.Println(paramList...)
|
||||
}
|
||||
|
||||
rows, err := b.LinkCommon.Query(sql, paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -324,21 +339,21 @@ func (b *Builder) GetSqlAndParams() (string, []interface{}) {
|
||||
groupStr, paramList := b.handleGroup(paramList)
|
||||
havingStr, paramList := b.handleHaving(paramList)
|
||||
orderStr, paramList := b.handleOrder(paramList)
|
||||
limitStr, paramList := b.handleLimit(b.offset, b.pageSize, paramList)
|
||||
lockStr := handleLockForUpdate(b.isLockForUpdate)
|
||||
limitStr, paramList := b.handleLimit(paramList)
|
||||
lockStr := b.handleLockForUpdate()
|
||||
|
||||
sqlStr := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
|
||||
sql := "SELECT " + fieldStr + " FROM " + tableName + " " + b.tableAlias + joinStr + whereStr + groupStr + havingStr + orderStr + limitStr + lockStr
|
||||
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
sql = convertToPostgresSql(sql)
|
||||
}
|
||||
|
||||
if b.isDebug {
|
||||
fmt.Println(sqlStr)
|
||||
//fmt.Println(paramList...)
|
||||
fmt.Println(sql)
|
||||
fmt.Println(paramList...)
|
||||
}
|
||||
|
||||
return sqlStr, paramList
|
||||
return sql, paramList
|
||||
}
|
||||
|
||||
// Update 更新记录
|
||||
@@ -351,10 +366,6 @@ func (b *Builder) Update(dest interface{}) (int64, error) {
|
||||
whereStr, paramList := b.handleWhere(paramList)
|
||||
sqlStr := "UPDATE " + b.getTableNameCommon(typeOf, valueOf) + setStr + whereStr
|
||||
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return b.ExecAffected(sqlStr, paramList...)
|
||||
}
|
||||
|
||||
@@ -378,18 +389,16 @@ func (b *Builder) Delete(destList ...interface{}) (int64, error) {
|
||||
whereStr, paramList := b.handleWhere(paramList)
|
||||
sqlStr := "DELETE FROM " + tableName + whereStr
|
||||
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return b.ExecAffected(sqlStr, paramList...)
|
||||
}
|
||||
|
||||
// Truncate 清空记录, sqlite3不支持此操作
|
||||
// Truncate 清空记录
|
||||
func (b *Builder) Truncate() (int64, error) {
|
||||
sqlStr := "TRUNCATE TABLE " + getTableNameByTable(b.table)
|
||||
sqlStr := ""
|
||||
if b.driverName == model.Sqlite3 {
|
||||
sqlStr = "DELETE FROM " + getTableNameByTable(b.table)
|
||||
} else {
|
||||
sqlStr = "TRUNCATE TABLE " + getTableNameByTable(b.table)
|
||||
}
|
||||
|
||||
return b.ExecAffected(sqlStr)
|
||||
@@ -399,7 +408,7 @@ func (b *Builder) Truncate() (int64, error) {
|
||||
func (b *Builder) Exists() (bool, error) {
|
||||
var obj IntStruct
|
||||
|
||||
err := b.selectCommon("", "1 as c", nil, "").Limit(0, 1).GetOne(&obj)
|
||||
err := b.selectCommon("", "1 AS c", nil, "").Limit(0, 1).GetOne(&obj)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -458,8 +467,9 @@ func (b *Builder) Value(field interface{}, dest interface{}) error {
|
||||
}
|
||||
|
||||
// Pluck 获取某一列的值
|
||||
func (b *Builder) Pluck(fieldName interface{}, values interface{}) error {
|
||||
b.Select(fieldName)
|
||||
func (b *Builder) Pluck(field interface{}, values interface{}) error {
|
||||
b.Select(field)
|
||||
fieldName := getFieldName(field)
|
||||
|
||||
rows, errRows := b.GetRows()
|
||||
defer rows.Close()
|
||||
@@ -506,10 +516,6 @@ func (b *Builder) Increment(field interface{}, step int) (int64, error) {
|
||||
whereStr, paramList := b.handleWhere(paramList)
|
||||
sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "+?" + whereStr
|
||||
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return b.ExecAffected(sqlStr, paramList...)
|
||||
}
|
||||
|
||||
@@ -520,31 +526,27 @@ func (b *Builder) Decrement(field interface{}, step int) (int64, error) {
|
||||
whereStr, paramList := b.handleWhere(paramList)
|
||||
sqlStr := "UPDATE " + getTableNameByTable(b.table) + " SET " + getFieldName(field) + "=" + getFieldName(field) + "-?" + whereStr
|
||||
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
}
|
||||
|
||||
return b.ExecAffected(sqlStr, paramList...)
|
||||
}
|
||||
|
||||
// Exec 通用执行-新增,更新,删除
|
||||
func (b *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
|
||||
func (b *Builder) Exec(sql string, paramList ...interface{}) (sql.Result, error) {
|
||||
if b.driverName == model.Postgres {
|
||||
sqlStr = convertToPostgresSql(sqlStr)
|
||||
sql = convertToPostgresSql(sql)
|
||||
}
|
||||
|
||||
if b.isDebug {
|
||||
fmt.Println(sqlStr)
|
||||
//fmt.Println(args...)
|
||||
fmt.Println(sql)
|
||||
fmt.Println(paramList...)
|
||||
}
|
||||
|
||||
smt, err1 := b.LinkCommon.Prepare(sqlStr)
|
||||
smt, err1 := b.LinkCommon.Prepare(sql)
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
defer smt.Close()
|
||||
|
||||
res, err2 := smt.Exec(args...)
|
||||
res, err2 := smt.Exec(paramList...)
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
@@ -554,8 +556,12 @@ func (b *Builder) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
// ExecAffected 通用执行-更新,删除
|
||||
func (b *Builder) ExecAffected(sqlStr string, args ...interface{}) (int64, error) {
|
||||
res, err := b.Exec(sqlStr, args...)
|
||||
func (b *Builder) ExecAffected(sql string, paramList ...interface{}) (int64, error) {
|
||||
if b.driverName == model.Postgres {
|
||||
sql = convertToPostgresSql(sql)
|
||||
}
|
||||
|
||||
res, err := b.Exec(sql, paramList...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -624,14 +630,25 @@ func (b *Builder) LockForUpdate(isLockForUpdate bool) *Builder {
|
||||
}
|
||||
|
||||
//拼接SQL,查询与筛选通用操作
|
||||
func (b *Builder) whereAndHaving(where []WhereItem, paramList []any) ([]string, []any) {
|
||||
func (b *Builder) whereAndHaving(where []WhereItem, paramList []any, isFromHaving bool) ([]string, []any) {
|
||||
var whereList []string
|
||||
for i := 0; i < len(where); i++ {
|
||||
allFieldName := ""
|
||||
if where[i].Prefix != "" {
|
||||
allFieldName += where[i].Prefix + "."
|
||||
}
|
||||
allFieldName += getFieldName(where[i].Field)
|
||||
|
||||
//如果是mssql或者Postgres,并且来自having的话,需要特殊处理
|
||||
if (b.driverName == model.Mssql || b.driverName == model.Postgres) && isFromHaving {
|
||||
fieldNameCurrent := getFieldName(where[i].Field)
|
||||
for m := 0; m < len(b.selectList); m++ {
|
||||
if fieldNameCurrent == getFieldName(b.selectList[m].FieldNew) {
|
||||
allFieldName += handleFieldWith(b.selectList[m])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allFieldName += getFieldName(where[i].Field)
|
||||
}
|
||||
|
||||
if "**builder.Builder" == reflect.TypeOf(where[i].Val).String() {
|
||||
executor := *(**Builder)(unsafe.Pointer(reflect.ValueOf(where[i].Val).Pointer()))
|
||||
|
Reference in New Issue
Block a user