320 lines
8.6 KiB
Go
320 lines
8.6 KiB
Go
package db
|
||
|
||
/**
|
||
* 数据库连接工具
|
||
* @author ZStudio
|
||
* @since 2021/9/8
|
||
* @File : db
|
||
*/
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"reflect"
|
||
"strings"
|
||
"time"
|
||
|
||
"xorm.io/core"
|
||
"xorm.io/xorm"
|
||
"xorm.io/xorm/log"
|
||
"xorm.io/xorm/names"
|
||
"xorm.io/xorm/schemas"
|
||
|
||
_ "github.com/denisenkom/go-mssqldb"
|
||
_ "github.com/go-sql-driver/mysql"
|
||
_ "modernc.org/sqlite"
|
||
)
|
||
|
||
var (
|
||
db *xorm.Engine
|
||
tables []any
|
||
initFuncs []func() error
|
||
)
|
||
|
||
func init() {
|
||
gonicNames := []string{"SSL", "UID"}
|
||
for _, name := range gonicNames {
|
||
names.LintGonicMapper[name] = true
|
||
}
|
||
}
|
||
|
||
type Database struct {
|
||
Type string
|
||
Host string
|
||
Port int
|
||
User string
|
||
Passwd string
|
||
Path string
|
||
Prev string
|
||
Debug bool
|
||
Migration bool
|
||
}
|
||
|
||
// Engine 表示 xorm 引擎或会话。
|
||
type Engine interface {
|
||
Table(tableNameOrBean any) *xorm.Session
|
||
Count(...any) (int64, error)
|
||
Decr(column string, arg ...any) *xorm.Session
|
||
Delete(...any) (int64, error)
|
||
Truncate(...any) (int64, error)
|
||
Exec(...any) (sql.Result, error)
|
||
Find(any, ...any) error
|
||
Get(beans ...any) (bool, error)
|
||
ID(any) *xorm.Session
|
||
In(string, ...any) *xorm.Session
|
||
Incr(column string, arg ...any) *xorm.Session
|
||
Insert(...any) (int64, error)
|
||
Iterate(any, xorm.IterFunc) error
|
||
Join(joinOperator string, tablename, condition any, args ...any) *xorm.Session
|
||
SQL(any, ...any) *xorm.Session
|
||
Where(any, ...any) *xorm.Session
|
||
Asc(colNames ...string) *xorm.Session
|
||
Desc(colNames ...string) *xorm.Session
|
||
Limit(limit int, start ...int) *xorm.Session
|
||
NoAutoTime() *xorm.Session
|
||
SumInt(bean any, columnName string) (res int64, err error)
|
||
Sync(...any) error
|
||
Select(string) *xorm.Session
|
||
NotIn(string, ...any) *xorm.Session
|
||
OrderBy(any, ...any) *xorm.Session
|
||
Exist(...any) (bool, error)
|
||
Distinct(...string) *xorm.Session
|
||
Query(...any) ([]map[string][]byte, error)
|
||
Cols(...string) *xorm.Session
|
||
Context(ctx context.Context) *xorm.Session
|
||
Ping() error
|
||
}
|
||
|
||
// InitEngine 初始化 xorm.Engine 并将其设置为 db.DefaultContext
|
||
func InitEngine(ctx context.Context, conf *Database) error {
|
||
_engine, err := newXORMEngine(conf)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
_engine.DatabaseTZ = time.Local // 必须
|
||
_engine.TZLocation = time.Local // 必须
|
||
|
||
_engine.SetMaxOpenConns(50) // 连接池中最大连接数
|
||
_engine.SetMaxIdleConns(10) // 连接池中最大空闲连接数
|
||
_engine.SetConnMaxLifetime(time.Second * 10) // 单个连接最大存活时间(单位:秒)
|
||
_engine.SetConnMaxIdleTime(time.Second * 5) // 设置连接可能处于空闲状态的最长时间
|
||
_engine.SetDefaultContext(ctx)
|
||
|
||
SetDefaultEngine(ctx, _engine)
|
||
return nil
|
||
}
|
||
|
||
// newXORMEngine 从配置返回一个新的 XORM 引擎
|
||
func newXORMEngine(config *Database) (*xorm.Engine, error) {
|
||
driverName := "sqlite"
|
||
dsbSource := config.Path
|
||
|
||
if config.Type != "" {
|
||
driverName = config.Type
|
||
dsbSource = fmt.Sprintf(
|
||
"%v:%v@tcp(%v:%v)/%v?charset=utf8&parseTime=True&loc=Local&timeout=1000ms",
|
||
config.User,
|
||
config.Passwd,
|
||
config.Host,
|
||
config.Port,
|
||
config.Path,
|
||
)
|
||
}
|
||
|
||
x, err := xorm.NewEngine(driverName, dsbSource)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if driverName == "mysql" {
|
||
x.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
|
||
} else if driverName == "mssql" {
|
||
x.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
|
||
}
|
||
//x.SetSchema(config.Database.Schema)
|
||
// 通过engine.Ping()来进行数据库的连接测试是否可以连接到数据库。
|
||
if err = x.Ping(); err != nil {
|
||
return nil, errors.New("数据库连接错误")
|
||
}
|
||
|
||
// 结构体与数据表的映射
|
||
tbMapper := core.NewPrefixMapper(core.SnakeMapper{}, config.Prev)
|
||
x.SetTableMapper(tbMapper)
|
||
|
||
// 开启调试模式和打印日志,会在控制台打印执行的sql
|
||
if config.Debug {
|
||
x.ShowSQL(true)
|
||
x.Logger().SetLevel(log.LOG_DEBUG)
|
||
} else {
|
||
x.Logger().SetLevel(log.LOG_OFF)
|
||
}
|
||
|
||
return x, nil
|
||
}
|
||
|
||
// TableInfo 通过对象返回表的信息
|
||
func TableInfo(v any) (*schemas.Table, error) {
|
||
return db.TableInfo(v)
|
||
}
|
||
|
||
// DumpTables 转储表信息
|
||
func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
|
||
return db.DumpTables(tables, w, tp...)
|
||
}
|
||
|
||
// RegisterModel 注册模型,如果提供了 initfunc,则在数据模型同步后将调用它
|
||
func RegisterModel(bean any, initFunc ...func() error) {
|
||
tables = append(tables, bean)
|
||
if len(initFuncs) > 0 && initFunc[0] != nil {
|
||
initFuncs = append(initFuncs, initFunc[0])
|
||
}
|
||
}
|
||
|
||
// SyncAllTables 同步所有表的模式,由单元测试代码需要
|
||
func SyncAllTables() error {
|
||
_, err := db.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{
|
||
WarnIfDatabaseColumnMissed: true,
|
||
}, tables...)
|
||
return err
|
||
}
|
||
|
||
// SetDefaultEngine 设置 db 的默认引擎
|
||
func SetDefaultEngine(c context.Context, eng *xorm.Engine) {
|
||
db = eng
|
||
DefaultContext = &Context{
|
||
Context: c,
|
||
e: eng,
|
||
}
|
||
}
|
||
|
||
// UnsetDefaultEngine 关闭并取消设置默认引擎
|
||
// 我们希望 SetDefaultEngine 和 UnsetDefaultEngine 可以成对出现,但现在不可能,
|
||
// 现在全局数据库引擎相关函数都是竞争的,现在没有优雅的关闭。
|
||
func UnsetDefaultEngine() {
|
||
if db != nil {
|
||
_ = db.Close()
|
||
}
|
||
DefaultContext = nil
|
||
}
|
||
|
||
// InitEngineWithMigration 初始化一个新的 xorm.Engine 并将其设置为 db.DefaultContext
|
||
// 如果提供的迁移函数失败,则此函数绝不能调用 .Sync()。
|
||
// 当从 "doctor" 命令调用时,迁移函数是一个版本检查,
|
||
// 如果迁移级别与预期值不同,则阻止 doctor 修复数据库中的任何内容。
|
||
func InitEngineWithMigration(ctx context.Context, config *Database, migrateFunc func(*xorm.Engine) error) (err error) {
|
||
if err = InitEngine(ctx, config); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err = db.Ping(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 我们必须在此处运行 migrateFunc,以防用户在以前创建的数据库上重新运行安装。
|
||
// 如果我们不这样做,那么表模式将被更改,并且当正确运行迁移时将会发生冲突。
|
||
//
|
||
// 只有在用户想要恢复旧数据库时才应该重新运行安装。
|
||
// 但是,我们应该仔细考虑是否应支持在已安装的实例上重新安装,
|
||
// 因为可能由于秘密重新初始化而导致其他问题。
|
||
if err = migrateFunc(db); err != nil {
|
||
return fmt.Errorf("migrate: %w", err)
|
||
}
|
||
|
||
if err = SyncAllTables(); err != nil {
|
||
return fmt.Errorf("sync database struct error: %w", err)
|
||
}
|
||
|
||
for _, initFunc := range initFuncs {
|
||
if err := initFunc(); err != nil {
|
||
return fmt.Errorf("initFunc failed: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// NamesToBean 返回一个 bean 列表或一个错误
|
||
func NamesToBean(names ...string) ([]any, error) {
|
||
beans := []any{}
|
||
if len(names) == 0 {
|
||
beans = append(beans, tables...)
|
||
return beans, nil
|
||
}
|
||
// 需要将提供的名称映射到 bean...
|
||
beanMap := make(map[string]any)
|
||
for _, bean := range tables {
|
||
|
||
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
|
||
beanMap[strings.ToLower(db.TableName(bean))] = bean
|
||
beanMap[strings.ToLower(db.TableName(bean, true))] = bean
|
||
}
|
||
|
||
gotBean := make(map[any]bool)
|
||
for _, name := range names {
|
||
bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))]
|
||
if !ok {
|
||
return nil, fmt.Errorf("no table found that matches: %s", name)
|
||
}
|
||
if !gotBean[bean] {
|
||
beans = append(beans, bean)
|
||
gotBean[bean] = true
|
||
}
|
||
}
|
||
return beans, nil
|
||
}
|
||
|
||
// DumpDatabase 根据特殊数据库 SQL 语法将数据库中的所有数据转储到文件系统。
|
||
func DumpDatabase(filePath, dbType string) error {
|
||
var tbs []*schemas.Table
|
||
for _, t := range tables {
|
||
t, err := db.TableInfo(t)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
tbs = append(tbs, t)
|
||
}
|
||
|
||
type Version struct {
|
||
ID int64 `xorm:"pk autoincr"`
|
||
Version int64
|
||
}
|
||
t, err := db.TableInfo(&Version{})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
tbs = append(tbs, t)
|
||
|
||
if len(dbType) > 0 {
|
||
return db.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType))
|
||
}
|
||
return db.DumpTablesToFile(tbs, filePath)
|
||
}
|
||
|
||
// MaxBatchInsertSize 返回表的最大批量插入大小
|
||
func MaxBatchInsertSize(bean any) int {
|
||
t, err := db.TableInfo(bean)
|
||
if err != nil {
|
||
return 50
|
||
}
|
||
return 999 / len(t.ColumnsSeq())
|
||
}
|
||
|
||
// IsTableNotEmpty 如果表至少有一条记录,则返回 true
|
||
func IsTableNotEmpty(tableName string) (bool, error) {
|
||
return db.Table(tableName).Exist()
|
||
}
|
||
|
||
// DeleteAllRecords 将删除此表的所有记录
|
||
func DeleteAllRecords(tableName string) error {
|
||
_, err := db.Exec(fmt.Sprintf("DELETE FROM %s", tableName))
|
||
return err
|
||
}
|
||
|
||
// GetMaxID 将返回表的最大 id
|
||
func GetMaxID(beanOrTableName any) (maxID int64, err error) {
|
||
_, err = db.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
|
||
return maxID, err
|
||
}
|