mirror of
https://github.com/gohouse/gorose.git
synced 2025-09-26 20:01:15 +08:00
561 lines
15 KiB
Go
561 lines
15 KiB
Go
package gorose
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"github.com/gohouse/gorose/v3/builder"
|
|
"github.com/gohouse/gorose/v3/driver"
|
|
"reflect"
|
|
"strings"
|
|
)
|
|
|
|
type Database struct {
|
|
*Engin
|
|
Driver *driver.Driver
|
|
Context *builder.Context
|
|
}
|
|
|
|
func NewDatabase(g *GoRose) *Database {
|
|
return &Database{
|
|
Driver: driver.NewDriver(g.driver),
|
|
Engin: NewEngin(g),
|
|
Context: builder.NewContext(g.prefix),
|
|
}
|
|
}
|
|
func (db *Database) Table(table any, alias ...string) *Database {
|
|
db.Context.TableClause.Table(table, alias...)
|
|
return db
|
|
}
|
|
|
|
// Select specifies the columns to retrieve.
|
|
// Select("a","b")
|
|
// Select("a.id as aid","b.id bid")
|
|
// Select("id,nickname name")
|
|
func (db *Database) Select(columns ...string) *Database {
|
|
db.Context.SelectClause.Select(columns...)
|
|
return db
|
|
}
|
|
|
|
// AddSelect 添加选择列
|
|
func (db *Database) AddSelect(columns ...string) *Database {
|
|
db.Context.SelectClause.AddSelect(columns...)
|
|
return db
|
|
}
|
|
|
|
// SelectRaw 允许直接在查询中插入原始SQL片段作为选择列。
|
|
func (db *Database) SelectRaw(raw string, binds ...any) *Database {
|
|
db.Context.SelectClause.SelectRaw(raw, binds...)
|
|
return db
|
|
}
|
|
|
|
// Join clause
|
|
func (db *Database) Join(table any, argOrFn ...any) *Database {
|
|
db.Context.JoinClause.Join(table, argOrFn...)
|
|
return db
|
|
}
|
|
func (db *Database) JoinOn(table any, fn func(on builder.IJoinOn)) *Database {
|
|
db.Context.JoinClause.JoinOn(table, fn)
|
|
return db
|
|
}
|
|
|
|
// LeftJoin clause
|
|
func (db *Database) LeftJoin(table any, argOrFn ...any) *Database {
|
|
db.Context.JoinClause.LeftJoin(table, argOrFn...)
|
|
return db
|
|
}
|
|
|
|
// RightJoin clause
|
|
func (db *Database) RightJoin(table any, argOrFn ...any) *Database {
|
|
db.Context.JoinClause.RightJoin(table, argOrFn...)
|
|
return db
|
|
}
|
|
|
|
// CrossJoin clause
|
|
func (db *Database) CrossJoin(table any, argOrFn ...any) *Database {
|
|
db.Context.JoinClause.CrossJoin(table, argOrFn...)
|
|
return db
|
|
}
|
|
func (db *Database) Where(column any, argsOrclosure ...any) *Database {
|
|
db.Context.WhereClause.Where(column, argsOrclosure...)
|
|
return db
|
|
}
|
|
func (db *Database) OrWhere(column any, argsOrclosure ...any) *Database {
|
|
db.Context.WhereClause.OrWhere(column, argsOrclosure...)
|
|
return db
|
|
}
|
|
|
|
// WhereRaw 在查询中添加一个原生SQL“where”条件。
|
|
//
|
|
// sql: 原生SQL条件字符串。
|
|
// bindings: SQL绑定参数数组。
|
|
func (db *Database) WhereRaw(raw string, bindings ...any) *Database {
|
|
db.Context.WhereClause.WhereRaw(raw, bindings...)
|
|
return db
|
|
}
|
|
func (db *Database) OrWhereRaw(raw string, bindings ...any) *Database {
|
|
db.Context.WhereClause.OrWhereRaw(raw, bindings...)
|
|
return db
|
|
}
|
|
|
|
// GroupBy 添加 GROUP BY 子句
|
|
func (db *Database) GroupBy(columns ...string) *Database {
|
|
db.Context.GroupClause.GroupBy(columns...)
|
|
return db
|
|
}
|
|
func (db *Database) GroupByRaw(columns ...string) *Database {
|
|
db.Context.GroupClause.GroupByRaw(columns...)
|
|
return db
|
|
}
|
|
|
|
// Having 添加 HAVING 子句, 同where
|
|
func (db *Database) Having(column any, argsOrClosure ...any) *Database {
|
|
db.Context.HavingClause.Where(column, argsOrClosure...)
|
|
return db
|
|
}
|
|
func (db *Database) OrHaving(column any, argsOrClosure ...any) *Database {
|
|
db.Context.HavingClause.OrWhere(column, argsOrClosure...)
|
|
return db
|
|
}
|
|
|
|
// HavingRaw 添加 HAVING 子句, 同where
|
|
func (db *Database) HavingRaw(raw string, argsOrClosure ...any) *Database {
|
|
db.Context.HavingClause.WhereRaw(raw, argsOrClosure...)
|
|
return db
|
|
}
|
|
func (db *Database) OrHavingRaw(raw string, argsOrClosure ...any) *Database {
|
|
db.Context.HavingClause.OrWhereRaw(raw, argsOrClosure...)
|
|
return db
|
|
}
|
|
|
|
// OrderBy adds an ORDER BY clause to the query.
|
|
func (db *Database) OrderBy(column string, directions ...string) *Database {
|
|
db.Context.OrderByClause.OrderBy(column, directions...)
|
|
return db
|
|
}
|
|
func (db *Database) OrderByRaw(column string) *Database {
|
|
db.Context.OrderByClause.OrderByRaw(column)
|
|
return db
|
|
}
|
|
|
|
// Limit 设置查询结果的限制数量。
|
|
func (db *Database) Limit(limit int) *Database {
|
|
db.Context.Limit(limit)
|
|
return db
|
|
}
|
|
|
|
// Offset 设置查询结果的偏移量。
|
|
func (db *Database) Offset(offset int) *Database {
|
|
db.Context.Offset(offset)
|
|
return db
|
|
}
|
|
|
|
// Page 页数,根据limit确定
|
|
func (db *Database) Page(num int) *Database {
|
|
db.Context.Page(num)
|
|
return db
|
|
}
|
|
|
|
// SharedLock 4 select ... locking in share mode
|
|
func (db *Database) SharedLock() *Database {
|
|
db.Context.SharedLock()
|
|
return db
|
|
}
|
|
|
|
// LockForUpdate 4 select ... for update
|
|
func (db *Database) LockForUpdate() *Database {
|
|
db.Context.LockForUpdate()
|
|
return db
|
|
}
|
|
|
|
func (db *Database) toBind(bind any) (err error) {
|
|
var prepare string
|
|
var binds []any
|
|
prepare, binds, err = db.ToSql()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
err = db.queryToBindResult(bind, prepare, binds...)
|
|
return
|
|
}
|
|
|
|
// Get 获取查询结果集。
|
|
//
|
|
// columns: 要获取的列名数组,如果不提供,则获取所有列。
|
|
func (db *Database) Get(columns ...string) (res []map[string]any, err error) {
|
|
var prepare string
|
|
var binds []any
|
|
prepare, binds, err = db.Select(columns...).ToSql()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
err = db.queryToBindResult(&res, prepare, binds...)
|
|
return
|
|
}
|
|
func (db *Database) First(columns ...string) (res map[string]any, err error) {
|
|
var prepare string
|
|
var binds []any
|
|
prepare, binds, err = db.Select(columns...).Limit(1).ToSql()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
res = make(map[string]any)
|
|
err = db.queryToBindResult(&res, prepare, binds...)
|
|
return
|
|
}
|
|
func (db *Database) Find(id int) (res map[string]any, err error) {
|
|
var prepare string
|
|
var binds []any
|
|
prepare, binds, err = db.Where("id", id).Limit(1).ToSql()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
res = make(map[string]any)
|
|
err = db.queryToBindResult(&res, prepare, binds...)
|
|
return
|
|
}
|
|
func (db *Database) queryToBindResult(bind any, query string, args ...any) (err error) {
|
|
return db.Engin.QueryTo(bind, query, args...)
|
|
}
|
|
|
|
func (db *Database) insert(obj any, arg builder.TypeToSqlInsertCase) (res sql.Result, err error) {
|
|
//segment, binds, err := db.ToSqlInsert(obj, ignoreCase, onDuplicateKeys, mustColumn...)
|
|
segment, binds, err := db.ToSqlInsert(obj, arg)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
return db.Engin.Exec(segment, binds...)
|
|
}
|
|
func (db *Database) Insert(obj any, mustColumn ...string) (affectedRows int64, err error) {
|
|
result, err := db.insert(obj, builder.TypeToSqlInsertCase{MustColumn: mustColumn})
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// InsertGetId 插入数据,获取并自增id
|
|
//
|
|
// 参考 https://laravel.com/docs/10.x/queries#auto-incrementing-ids
|
|
func (db *Database) InsertGetId(obj any, mustColumn ...string) (lastInsertId int64, err error) {
|
|
result, err := db.insert(obj, builder.TypeToSqlInsertCase{MustColumn: mustColumn})
|
|
if err != nil {
|
|
return lastInsertId, err
|
|
}
|
|
return result.LastInsertId()
|
|
}
|
|
|
|
// InsertOrIgnore 插入数据,忽略错误。
|
|
//
|
|
// 参考 https://laravel.com/docs/10.x/queries#insert-statements
|
|
func (db *Database) InsertOrIgnore(obj any, mustColumn ...string) (affectedRows int64, err error) {
|
|
result, err := db.insert(obj, builder.TypeToSqlInsertCase{IsIgnoreCase: true, MustColumn: mustColumn})
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// Upsert 插入数据,如果存在则更新。
|
|
//
|
|
// 参考 https://laravel.com/docs/10.x/queries#upserts
|
|
// 如果是mysql,则不需要填写第二个参数,MySQL会自动处理唯一索引和主键冲突问题
|
|
//
|
|
// eg: Upsert(obj, []string{"id"}, []string{"age"}, "id", "name")
|
|
func (db *Database) Upsert(obj any, onDuplicateKeys, updateFields []string, mustColumn ...string) (affectedRows int64, err error) {
|
|
result, err := db.insert(obj, builder.TypeToSqlInsertCase{OnDuplicateKeys: onDuplicateKeys, UpdateFields: updateFields, MustColumn: mustColumn})
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// Replace 插入数据,如果存在则替换。
|
|
//
|
|
// 参考 mysql replace into 用法
|
|
func (db *Database) Replace(obj any, mustColumn ...string) (affectedRows int64, err error) {
|
|
result, err := db.insert(obj, builder.TypeToSqlInsertCase{IsReplace: true, MustColumn: mustColumn})
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// UpdateOrInsert 更新数据,如果存在则更新,否则插入。
|
|
//
|
|
// 参考 https://laravel.com/docs/10.x/queries#update-or-insert
|
|
func (db *Database) UpdateOrInsert(conditions, data map[string]any) (affectedRows int64, err error) {
|
|
dbTmp := db.Where(conditions)
|
|
var exists bool
|
|
if exists, err = dbTmp.Exists(); err != nil {
|
|
return
|
|
}
|
|
if exists {
|
|
return dbTmp.Update(data)
|
|
}
|
|
return dbTmp.Insert(data)
|
|
}
|
|
|
|
func (db *Database) Update(obj any, mustColumn ...string) (affectedRows int64, err error) {
|
|
segment, binds, err := db.ToSqlUpdate(obj, mustColumn...)
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return db.Engin.execute(segment, binds...)
|
|
}
|
|
|
|
func (db *Database) Delete(obj any, mustColumn ...string) (affectedRows int64, err error) {
|
|
segment, binds, err := db.ToSqlDelete(obj, mustColumn...)
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return db.Engin.execute(segment, binds...)
|
|
}
|
|
|
|
func (db *Database) incDecEach(symbol string, data map[string]any) (affectedRows int64, err error) {
|
|
prepare, values, err := db.ToSqlIncDec(symbol, data)
|
|
if err != nil {
|
|
return affectedRows, err
|
|
}
|
|
return db.Engin.execute(prepare, values...)
|
|
}
|
|
func (db *Database) incDec(symbol string, column string, steps ...any) (affectedRows int64, err error) {
|
|
var step any = 1
|
|
if len(steps) > 0 {
|
|
step = steps[0]
|
|
}
|
|
return db.incDecEach(symbol, map[string]any{column: step})
|
|
}
|
|
func (db *Database) Increment(column string, steps ...any) (affectedRows int64, err error) {
|
|
return db.incDec("+", column, steps...)
|
|
}
|
|
func (db *Database) Decrement(column string, steps ...any) (affectedRows int64, err error) {
|
|
return db.incDec("-", column, steps...)
|
|
}
|
|
func (db *Database) IncrementEach(data map[string]any) (affectedRows int64, err error) {
|
|
return db.incDecEach("+", data)
|
|
}
|
|
func (db *Database) DecrementEach(data map[string]any) (affectedRows int64, err error) {
|
|
return db.incDecEach("-", data)
|
|
}
|
|
|
|
// func (db *Database) Aggregate(functions, columns string) (float64, error) {}
|
|
func (db *Database) aggregateSingle(bind any, function, column string) error {
|
|
prepare, values, err := db.ToSqlAggregate(function, column)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return db.Engin.QueryRow(prepare, values...).Scan(bind)
|
|
}
|
|
func (db *Database) Max(column string) (res float64, err error) {
|
|
err = db.aggregateSingle(&res, "max", column)
|
|
return
|
|
}
|
|
func (db *Database) Min(column string) (res float64, err error) {
|
|
err = db.aggregateSingle(&res, "min", column)
|
|
return
|
|
}
|
|
func (db *Database) Sum(column string) (res float64, err error) {
|
|
err = db.aggregateSingle(&res, "sum", column)
|
|
return
|
|
}
|
|
func (db *Database) Avg(column string) (res float64, err error) {
|
|
err = db.aggregateSingle(&res, "avg", column)
|
|
return
|
|
}
|
|
func (db *Database) Count() (res int64, err error) {
|
|
err = db.aggregateSingle(&res, "count", "*")
|
|
return
|
|
}
|
|
|
|
// List 获取指定列的值列表。
|
|
func (db *Database) List(column string) (res []any, err error) {
|
|
ress, err := db.Get(column)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
for _, v := range ress {
|
|
res = append(res, v[column])
|
|
}
|
|
return
|
|
}
|
|
|
|
// Pluck 从查询结果集中获取键值对列表。
|
|
func (db *Database) Pluck(column string, keyColumn string) (res map[any]any, err error) {
|
|
ress, err := db.Get(column, keyColumn)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
res = make(map[any]any)
|
|
for _, v := range ress {
|
|
res[v[keyColumn]] = v[column]
|
|
}
|
|
return
|
|
}
|
|
func (db *Database) Value(column string) (res any, err error) {
|
|
first, err := db.First(column)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
return first[column], err
|
|
}
|
|
func (db *Database) Exists(bind ...any) (b bool, err error) {
|
|
prepare, values, err := db.ToSqlExists(bind...)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
err = db.Engin.QueryRow(prepare, values...).Scan(&b)
|
|
return
|
|
}
|
|
func (db *Database) DoesntExist(bind ...any) (b bool, err error) {
|
|
b, err = db.Exists(bind...)
|
|
return !b, err
|
|
}
|
|
|
|
func (db *Database) Union(b ...builder.IBuilder) *Database {
|
|
db.Context.UnionClause.Union(b...)
|
|
return db
|
|
}
|
|
|
|
func (db *Database) UnionAll(b ...builder.IBuilder) *Database {
|
|
db.Context.UnionClause.UnionAll(b...)
|
|
return db
|
|
}
|
|
|
|
func (db *Database) Truncate(obj ...any) (affectedRows int64, err error) {
|
|
var table string
|
|
var dbTmp = db
|
|
if len(obj) > 0 {
|
|
dbTmp = db.Table(obj[0])
|
|
}
|
|
table, _, err = dbTmp.ToSqlTable()
|
|
if err != nil {
|
|
return
|
|
}
|
|
return db.Engin.execute(fmt.Sprintf("TRUNCATE TABLE %s", table))
|
|
}
|
|
|
|
// TxHandler is the function handed to transaction closures; calling it yields
// the *Database to issue statements on.
type TxHandler func() *Database
|
|
|
|
func (db *Database) Begin() (tx TxHandler, err error) {
|
|
return func() *Database {
|
|
db.Context = builder.NewContext(db.prefix)
|
|
return db
|
|
}, db.Engin.Begin()
|
|
}
|
|
|
|
func (db *Database) Transaction(closure ...func(TxHandler) error) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, v := range closure {
|
|
err = v(tx)
|
|
if err != nil {
|
|
err2 := db.Rollback()
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
return db.Commit()
|
|
}
|
|
|
|
////////////////////// "To" operations //////////////////////
// The methods below bind query results to a concrete, caller-supplied type:
// Get(), First(), Find() => To() / Bind()
// Value() => ValueTo()
// List() => ListTo()
// Pluck() => PluckTo()
// Max() => MaxTo()
// Min() => MinTo()
// Sum() => SumTo()
|
|
|
|
// To 通用查询,go 绑定 struct/map
|
|
func (db *Database) To(obj any, mustColumn ...string) (err error) {
|
|
var prepare string
|
|
var binds []any
|
|
prepare, binds, err = db.ToSqlTo(obj, mustColumn...)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
err = db.queryToBindResult(obj, prepare, binds...)
|
|
return
|
|
}
|
|
|
|
// Bind 查询结果,绑定到结构体
|
|
// 与 To 的区别是,绑定字段不作为查询依据
|
|
// 经常用在join语句中,手动指定查询字段,然后直接绑定到一个结构体
|
|
func (db *Database) Bind(obj any) (err error) {
|
|
var prepare string
|
|
var binds []any
|
|
prepare, binds, err = db.ToSql()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
err = db.queryToBindResult(obj, prepare, binds...)
|
|
return
|
|
}
|
|
|
|
// ListTo 获取指定列的值列表。
|
|
func (db *Database) ListTo(column string, obj any) (err error) {
|
|
return db.Select(column).toBind(obj)
|
|
}
|
|
|
|
// PluckTo 从查询结果集中获取键值对列表。
|
|
func (db *Database) PluckTo(column string, keyColumn string, obj any) (err error) {
|
|
ress, err := db.Get(column, keyColumn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rfv := reflect.Indirect(reflect.ValueOf(obj))
|
|
for _, v := range ress {
|
|
rfv2 := reflect.ValueOf(v)
|
|
keys := rfv2.MapKeys()
|
|
key0 := keys[0].String()
|
|
key1 := keys[1].String()
|
|
if strings.HasSuffix(keyColumn, key0) {
|
|
rfv.SetMapIndex(reflect.ValueOf(v[key0]), reflect.ValueOf(v[key1]))
|
|
} else {
|
|
rfv.SetMapIndex(reflect.ValueOf(v[key1]), reflect.ValueOf(v[key0]))
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// ValueTo 获取指定字段的值,并绑定到给定的变量中
|
|
func (db *Database) ValueTo(column string, obj any) (err error) {
|
|
prepare, values, err := db.Select(column).ToSql()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return db.QueryRow(prepare, values...).Scan(obj)
|
|
}
|
|
|
|
// MaxTo 同 Max
|
|
//
|
|
// obj为具体类型的变量,如: var a int, obj 为 &a, 可以得到具体类型
|
|
func (db *Database) MaxTo(column string, obj any) (err error) {
|
|
err = db.aggregateSingle(obj, "max", column)
|
|
return
|
|
}
|
|
|
|
// MinTo 同 Min, 参考 MaxTo
|
|
func (db *Database) MinTo(column string, obj any) (err error) {
|
|
err = db.aggregateSingle(obj, "min", column)
|
|
return
|
|
}
|
|
|
|
// SumTo 同 Sum, 参考 MaxTo
|
|
func (db *Database) SumTo(column string, obj any) (err error) {
|
|
err = db.aggregateSingle(obj, "sum", column)
|
|
return
|
|
}
|