mirror of
https://github.com/tobycroft/gorose-pro.git
synced 2025-09-26 20:51:27 +08:00
548 lines
14 KiB
Go
548 lines
14 KiB
Go
package gorose
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gohouse/t"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// beginStatus is the value of Session.tx_index when no savepoint is open,
// i.e. the session is at the outermost transaction level. Commit and Rollback
// compare the counter against it to decide whether to finish the whole
// transaction or only unwind one nesting level.
const beginStatus = 0
|
|
|
|
// Session ...
|
|
type Session struct {
|
|
IEngin
|
|
IBinder
|
|
master *sql.DB
|
|
tx *sql.Tx
|
|
slave *sql.DB
|
|
lastInsertId int64
|
|
sqlLogs []string
|
|
lastSql string
|
|
union interface{}
|
|
transaction bool
|
|
tx_index int64
|
|
err error
|
|
}
|
|
|
|
// Compile-time check that *Session implements ISession.
var _ ISession = (*Session)(nil)
|
|
|
|
// NewSession : 初始化 Session
|
|
func NewSession(e IEngin) *Session {
|
|
|
|
var s = new(Session)
|
|
s.IEngin = e
|
|
// 初始化 IBinder
|
|
s.SetIBinder(NewBinder())
|
|
|
|
s.master = e.GetExecuteDB()
|
|
s.slave = e.GetQueryDB()
|
|
//初始化数据
|
|
s.tx_index = beginStatus
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *Session) Close() {
|
|
s.master.Close()
|
|
s.slave.Close()
|
|
}
|
|
|
|
// GetIEngin 获取engin
|
|
func (s *Session) GetIEngin() IEngin {
|
|
return s.IEngin
|
|
}
|
|
|
|
// GetDriver 获取驱动
|
|
func (s *Session) SetIEngin(ie IEngin) {
|
|
s.IEngin = ie
|
|
}
|
|
|
|
// Bind : 传入绑定结果的对象, 参数一为对象, 可以是 struct, gorose.MapRow 或对应的切片
|
|
// 如果是做非query操作,第一个参数也可以仅仅指定为字符串表名
|
|
func (s *Session) Bind(tab interface{}) ISession {
|
|
//fmt.Println(tab, NewBinder(tab))
|
|
//s.SetIBinder(NewBinder(tab))
|
|
s.GetIBinder().SetBindOrigin(tab)
|
|
s.err = s.IBinder.BindParse(s.GetIEngin().GetPrefix())
|
|
return s
|
|
}
|
|
|
|
// GetBinder 获取绑定对象
|
|
func (s *Session) GetErr() error {
|
|
return s.err
|
|
}
|
|
|
|
// GetBinder 获取绑定对象
|
|
func (s *Session) SetIBinder(ib IBinder) {
|
|
s.IBinder = ib
|
|
}
|
|
|
|
// GetBinder 获取绑定对象
|
|
func (s *Session) GetIBinder() IBinder {
|
|
return s.IBinder
|
|
}
|
|
|
|
// GetBinder 获取绑定对象
|
|
func (s *Session) ResetBinderResult() {
|
|
_ = s.IBinder.BindParse(s.GetIEngin().GetPrefix())
|
|
}
|
|
|
|
// GetTableName 获取解析后的名字, 提供给orm使用
|
|
// 为什么要在这里重复添加该方法, 而不是直接继承 IBinder 的方法呢?
|
|
// 是因为, 这里涉及到表前缀的问题, 只能通过session来传递, 所以IOrm就可以选择直接继承
|
|
func (s *Session) GetTableName() (string, error) {
|
|
//err := s.IBinder.BindParse(s.GetIEngin().GetPrefix())
|
|
//fmt.Println(s.GetIBinder())
|
|
return s.GetIBinder().GetBindName(), s.err
|
|
}
|
|
|
|
// Begin ...
|
|
func (s *Session) Begin() (err error) {
|
|
s.SetTransaction(true)
|
|
if s.tx == nil {
|
|
s.tx, err = s.master.Begin()
|
|
} else {
|
|
num := atomic.AddInt64(&s.tx_index, 1)
|
|
err = s.SavePoint("sp_" + strconv.FormatInt(num, 10))
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
} else {
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Rollback ...
|
|
func (s *Session) Rollback() (err error) {
|
|
if s.tx != nil && s.transaction == true {
|
|
if atomic.LoadInt64(&s.tx_index) == beginStatus {
|
|
s.SetTransaction(false)
|
|
err = s.tx.Rollback()
|
|
s.tx = nil
|
|
} else {
|
|
num := atomic.LoadInt64(&s.tx_index)
|
|
err = s.RollbackTo("sp_" + strconv.FormatInt(num, 10))
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
} else {
|
|
atomic.AddInt64(&s.tx_index, -1)
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// RollbackTo ...
|
|
func (s *Session) RollbackTo(savepoint string) (err error) {
|
|
if s.tx != nil && s.transaction == true {
|
|
_, err = s.tx.Exec("rollback to " + savepoint)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *Session) SavePoint(savepoint string) (err error) {
|
|
if s.tx != nil && s.transaction == true {
|
|
_, err = s.tx.Exec("savepoint " + savepoint)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Commit ...
|
|
func (s *Session) Commit() (err error) {
|
|
if s.tx != nil {
|
|
if atomic.LoadInt64(&s.tx_index) == beginStatus {
|
|
s.SetTransaction(false)
|
|
err = s.tx.Commit()
|
|
s.tx = nil
|
|
} else {
|
|
atomic.AddInt64(&s.tx_index, -1)
|
|
}
|
|
} else {
|
|
s.SetTransaction(false)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Transaction ...
|
|
func (s *Session) Transaction(closers ...func(ses ISession) error) (err error) {
|
|
err = s.Begin()
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return err
|
|
}
|
|
|
|
for _, closer := range closers {
|
|
err = closer(s)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
_ = s.Rollback()
|
|
return
|
|
}
|
|
}
|
|
return s.Commit()
|
|
}
|
|
|
|
// Query ...
|
|
func (s *Session) Query(sqlstring string, args ...interface{}) (result []Data, err error) {
|
|
// 记录开始时间
|
|
start := time.Now()
|
|
//withRunTimeContext(func() {
|
|
if s.err != nil {
|
|
err = s.err
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
}
|
|
// 记录sqlLog
|
|
s.lastSql = fmt.Sprint(sqlstring, ", ", args)
|
|
//if s.IfEnableSqlLog() {
|
|
// s.sqlLogs = append(s.sqlLogs, s.lastSql)
|
|
//}
|
|
|
|
var stmt *sql.Stmt
|
|
// 如果是事务, 则从主库中读写
|
|
if s.tx == nil {
|
|
stmt, err = s.slave.Prepare(sqlstring)
|
|
} else {
|
|
stmt, err = s.tx.Prepare(sqlstring)
|
|
}
|
|
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
|
|
defer stmt.Close()
|
|
rows, err := stmt.Query(args...)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
|
|
// make sure we always close rows
|
|
defer rows.Close()
|
|
|
|
err = s.scan(rows)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
//}, func(duration time.Duration) {
|
|
// //if duration.Seconds() > 1 {
|
|
// // s.GetIEngin().GetLogger().Slow(s.LastSql(), duration)
|
|
// //} else {
|
|
// // s.GetIEngin().GetLogger().Sql(s.LastSql(), duration)
|
|
// //}
|
|
//})
|
|
|
|
timeduration := time.Since(start)
|
|
//if timeduration.Seconds() > 1 {
|
|
s.GetIEngin().GetLogger().Slow(s.LastSql(), timeduration)
|
|
//} else {
|
|
s.GetIEngin().GetLogger().Sql(s.LastSql(), timeduration)
|
|
//}
|
|
|
|
result = s.GetIBinder().GetBindAll()
|
|
return
|
|
}
|
|
|
|
// Execute ...
|
|
func (s *Session) Execute(sqlstring string, args ...interface{}) (rowsAffected int64, err error) {
|
|
// 记录开始时间
|
|
start := time.Now()
|
|
//withRunTimeContext(func() {
|
|
// err = s.GetIBinder().BindParse(s.GetIEngin().GetPrefix())
|
|
if s.err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
s.lastSql = fmt.Sprint(sqlstring, ", ", args)
|
|
//// 记录sqlLog
|
|
//if s.IfEnableSqlLog() {
|
|
// s.sqlLogs = append(s.sqlLogs, s.lastSql)
|
|
//}
|
|
|
|
var operType = strings.ToLower(sqlstring[0:6])
|
|
if operType == "select" {
|
|
err = errors.New("Execute does not allow select operations, please use Query")
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
|
|
var stmt *sql.Stmt
|
|
if s.tx == nil {
|
|
stmt, err = s.master.Prepare(sqlstring)
|
|
} else {
|
|
stmt, err = s.tx.Prepare(sqlstring)
|
|
}
|
|
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
|
|
//var err error
|
|
defer stmt.Close()
|
|
result, err := stmt.Exec(args...)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return
|
|
}
|
|
|
|
if operType == "insert" || operType == "replace" {
|
|
// get last insert id
|
|
lastInsertId, err := result.LastInsertId()
|
|
if err == nil {
|
|
s.lastInsertId = lastInsertId
|
|
} else {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
}
|
|
}
|
|
// get rows affected
|
|
rowsAffected, err = result.RowsAffected()
|
|
timeduration := time.Since(start)
|
|
//}, func(duration time.Duration) {
|
|
if timeduration.Seconds() > 1 {
|
|
s.GetIEngin().GetLogger().Slow(s.LastSql(), timeduration)
|
|
} else {
|
|
s.GetIEngin().GetLogger().Sql(s.LastSql(), timeduration)
|
|
}
|
|
//})
|
|
return
|
|
}
|
|
|
|
// LastInsertId ...
|
|
func (s *Session) LastInsertId() int64 {
|
|
return s.lastInsertId
|
|
}
|
|
|
|
// LastSql ...
|
|
func (s *Session) LastSql() string {
|
|
return s.lastSql
|
|
}
|
|
|
|
func (s *Session) scan(rows *sql.Rows) (err error) {
|
|
// 如果不需要绑定, 则需要初始化一下binder
|
|
if s.GetIBinder() == nil {
|
|
s.SetIBinder(NewBinder())
|
|
}
|
|
// 检查实多维数组还是一维数组
|
|
switch s.GetBindType() {
|
|
case OBJECT_STRING:
|
|
err = s.scanAll(rows)
|
|
case OBJECT_STRUCT, OBJECT_STRUCT_SLICE:
|
|
err = s.scanStructAll(rows)
|
|
//case OBJECT_MAP, OBJECT_MAP_T:
|
|
// err = s.scanMap(rows, s.GetBindResult())
|
|
case OBJECT_MAP, OBJECT_MAP_T, OBJECT_MAP_SLICE, OBJECT_MAP_SLICE_T:
|
|
err = s.scanMapAll(rows)
|
|
case OBJECT_NIL:
|
|
err = s.scanAll(rows)
|
|
default:
|
|
err = errors.New("Bind value error")
|
|
}
|
|
return
|
|
}
|
|
|
|
//func (s *Session) scanMap(rows *sql.Rows, dst interface{}) (err error) {
|
|
// return s.scanMapAll(rows, dst)
|
|
//}
|
|
|
|
func (s *Session) scanMapAll(rows *sql.Rows) (err error) {
|
|
var columns []string
|
|
// 获取查询的所有字段
|
|
if columns, err = rows.Columns(); err != nil {
|
|
return
|
|
}
|
|
count := len(columns)
|
|
|
|
for rows.Next() {
|
|
// 定义要绑定的结果集
|
|
values := make([]interface{}, count)
|
|
scanArgs := make([]interface{}, count)
|
|
for i := 0; i < count; i++ {
|
|
scanArgs[i] = &values[i]
|
|
}
|
|
// 获取结果
|
|
_ = rows.Scan(scanArgs...)
|
|
|
|
// 定义预设的绑定对象
|
|
//fmt.Println(reflect.TypeOf(s.GetBindResult()).Kind())
|
|
var bindResultTmp = reflect.MakeMap(reflect.Indirect(reflect.ValueOf(s.GetBindResult())).Type())
|
|
//// 定义union操作的map返回
|
|
//var unionTmp = map[string]interface{}{}
|
|
for i, col := range columns {
|
|
var v interface{}
|
|
val := values[i]
|
|
if b, ok := val.([]byte); ok {
|
|
v = string(b)
|
|
} else {
|
|
v = val
|
|
}
|
|
// 如果是union操作就不需要绑定数据直接返回, 否则就绑定数据
|
|
//TODO 这里可能有点问题, 比如在group时, 返回的结果不止一条, 这里直接返回的就是第一条
|
|
// 默认其实只是取了第一条, 满足常规的 union 操作(count,sum,max,min,avg)而已
|
|
// 后边需要再行完善, 以便group时使用
|
|
// 具体完善方法: 就是这里断点去掉, 不直接绑定union, 新增一个map,将结果放在map中,在方法最后统一返回
|
|
if s.GetUnion() != nil {
|
|
s.union = v
|
|
return
|
|
// 以下上通用解决方法
|
|
//unionTmp[col] = v
|
|
//s.union = unionTmp
|
|
} else {
|
|
br := reflect.Indirect(reflect.ValueOf(s.GetBindResult()))
|
|
switch s.GetBindType() {
|
|
case OBJECT_MAP_T, OBJECT_MAP_SLICE_T: // t.T类型
|
|
// 绑定到一条数据结果对象上,方便其他地方的调用,永远存储最新一条
|
|
br.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(t.New(v)))
|
|
// 跟上一行干的事是一样的, 只不过防止上一行的数据被后续的数据改变, 而无法提供给下边多条数据报错的需要
|
|
if s.GetBindType() == OBJECT_MAP_SLICE || s.GetBindType() == OBJECT_MAP_SLICE_T {
|
|
|
|
bindResultTmp.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(t.New(v)))
|
|
}
|
|
default: // 普通类型map[string]interface{}, 具体代码注释参照 上一个 case
|
|
br.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(v))
|
|
if s.GetBindType() == OBJECT_MAP_SLICE || s.GetBindType() == OBJECT_MAP_SLICE_T {
|
|
bindResultTmp.SetMapIndex(reflect.ValueOf(col), reflect.ValueOf(v))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// 如果是union操作就不需要绑定数据直接返回, 否则就绑定数据
|
|
if s.GetUnion() == nil {
|
|
// 如果是多条数据集, 就插入到对应的结果集slice上
|
|
if s.GetBindType() == OBJECT_MAP_SLICE || s.GetBindType() == OBJECT_MAP_SLICE_T {
|
|
s.GetBindResultSlice().Set(reflect.Append(s.GetBindResultSlice(), bindResultTmp))
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// ScanAll scans all sql result rows into a slice of structs.
|
|
// It reads all rows and closes rows when finished.
|
|
// dst should be a pointer to a slice of the appropriate type.
|
|
// The new results will be appended to any existing data in dst.
|
|
func (s *Session) scanStructAll(rows *sql.Rows) error {
|
|
// check if there is data waiting
|
|
//if !rows.Next() {
|
|
// if err := rows.Err(); err != nil {
|
|
// s.GetIEngin().GetLogger().Error(err.Error())
|
|
// return err
|
|
// }
|
|
// return sql.ErrNoRows
|
|
//}
|
|
var sfs = structForScan(s.GetBindResult())
|
|
for rows.Next() {
|
|
if s.GetUnion() != nil {
|
|
var union interface{}
|
|
err := rows.Scan(&union)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return err
|
|
}
|
|
s.union = union
|
|
return err
|
|
}
|
|
// scan it
|
|
//fmt.Printf("%#v \n",structForScan(s.GetBindResult()))
|
|
err := rows.Scan(sfs...)
|
|
if err != nil {
|
|
s.GetIEngin().GetLogger().Error(err.Error())
|
|
return err
|
|
}
|
|
// 如果是union操作就不需要绑定数据直接返回, 否则就绑定数据
|
|
if s.GetUnion() == nil {
|
|
// 如果是多条数据集, 就插入到对应的结果集slice上
|
|
if s.GetBindType() == OBJECT_STRUCT_SLICE {
|
|
// add to the result slice
|
|
s.GetBindResultSlice().Set(reflect.Append(s.GetBindResultSlice(),
|
|
reflect.Indirect(reflect.ValueOf(s.GetBindResult()))))
|
|
}
|
|
}
|
|
}
|
|
|
|
return rows.Err()
|
|
}
|
|
|
|
func (s *Session) scanAll(rows *sql.Rows) (err error) {
|
|
var columns []string
|
|
// 获取查询的所有字段
|
|
if columns, err = rows.Columns(); err != nil {
|
|
return
|
|
}
|
|
count := len(columns)
|
|
|
|
var result = []Data{}
|
|
for rows.Next() {
|
|
// 定义要绑定的结果集
|
|
values := make([]interface{}, count)
|
|
scanArgs := make([]interface{}, count)
|
|
for i := 0; i < count; i++ {
|
|
scanArgs[i] = &values[i]
|
|
}
|
|
// 获取结果
|
|
_ = rows.Scan(scanArgs...)
|
|
|
|
// 定义预设的绑定对象
|
|
var resultTmp = Data{}
|
|
//// 定义union操作的map返回
|
|
//var unionTmp = map[string]interface{}{}
|
|
for i, col := range columns {
|
|
var v interface{}
|
|
val := values[i]
|
|
if b, ok := val.([]byte); ok {
|
|
v = string(b)
|
|
} else {
|
|
v = val
|
|
}
|
|
if s.GetUnion() != nil {
|
|
s.union = v
|
|
return
|
|
// 以下上通用解决方法
|
|
//unionTmp[col] = v
|
|
//s.union = unionTmp
|
|
}
|
|
resultTmp[col] = v
|
|
}
|
|
result = append(result, resultTmp)
|
|
}
|
|
s.IBinder.SetBindAll(result)
|
|
return
|
|
}
|
|
|
|
// SetUnion ...
|
|
func (s *Session) SetUnion(u interface{}) {
|
|
s.union = u
|
|
}
|
|
|
|
// GetUnion ...
|
|
func (s *Session) GetUnion() interface{} {
|
|
return s.union
|
|
}
|
|
|
|
// SetTransaction ...
|
|
func (s *Session) SetTransaction(b bool) {
|
|
s.transaction = b
|
|
}
|
|
|
|
// GetTransaction 提供给 orm 使用的, 方便reset操作
|
|
func (s *Session) GetTransaction() bool {
|
|
return s.transaction
|
|
}
|