mirror of
http://github.com/goal-web/database
synced 2025-12-24 10:40:53 +08:00
feat: 优化工厂方法,并且支持非goal app场景下使用本包
This commit is contained in:
18
factory.go
18
factory.go
@@ -3,17 +3,31 @@ package database
|
||||
import (
|
||||
"errors"
|
||||
"github.com/goal-web/contracts"
|
||||
"github.com/goal-web/database/drivers"
|
||||
"github.com/goal-web/supports/utils"
|
||||
)
|
||||
|
||||
type Factory struct {
|
||||
events contracts.EventDispatcher
|
||||
config contracts.Config
|
||||
connections map[string]contracts.DBConnection
|
||||
drivers map[string]contracts.DBConnector
|
||||
dbConfig Config
|
||||
}
|
||||
|
||||
func NewFactory(config Config, events contracts.EventDispatcher) contracts.DBFactory {
|
||||
return &Factory{
|
||||
events: events,
|
||||
dbConfig: config,
|
||||
connections: make(map[string]contracts.DBConnection),
|
||||
drivers: map[string]contracts.DBConnector{
|
||||
"mysql": drivers.MysqlConnector,
|
||||
"postgres": drivers.PostgresSqlConnector,
|
||||
"sqlite": drivers.SqliteConnector,
|
||||
"clickhouse": drivers.ClickHouseConnector,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (factory *Factory) Connection(name ...string) contracts.DBConnection {
|
||||
connection := factory.dbConfig.Default
|
||||
if len(name) > 0 && name[0] != "" {
|
||||
@@ -33,7 +47,7 @@ func (factory *Factory) Extend(name string, driver contracts.DBConnector) {
|
||||
}
|
||||
|
||||
func (factory *Factory) make(name string) contracts.DBConnection {
|
||||
config := factory.config.Get("database").(Config)
|
||||
config := factory.dbConfig
|
||||
|
||||
if connectionConfig, existsConnection := config.Connections[name]; existsConnection {
|
||||
driverName := utils.GetStringField(connectionConfig, "driver")
|
||||
|
||||
@@ -2,11 +2,12 @@ package database
|
||||
|
||||
import (
|
||||
"github.com/goal-web/contracts"
|
||||
"github.com/goal-web/database/drivers"
|
||||
"github.com/goal-web/database/migrations"
|
||||
"github.com/goal-web/database/table"
|
||||
)
|
||||
|
||||
type ServiceProvider struct {
|
||||
app contracts.Application
|
||||
migrations contracts.Migrations
|
||||
}
|
||||
|
||||
@@ -15,6 +16,7 @@ func NewService(migrations contracts.Migrations) contracts.ServiceProvider {
|
||||
}
|
||||
|
||||
func (provider *ServiceProvider) Register(application contracts.Application) {
|
||||
provider.app = application
|
||||
application.Instance("migrations", provider.migrations)
|
||||
application.Singleton("migrations.table", func(config contracts.Config) string {
|
||||
return config.Get("database").(Config).Migrations
|
||||
@@ -22,18 +24,7 @@ func (provider *ServiceProvider) Register(application contracts.Application) {
|
||||
|
||||
application.Singleton("db.factory", func(config contracts.Config) contracts.DBFactory {
|
||||
events, _ := application.Get("events").(contracts.EventDispatcher)
|
||||
return &Factory{
|
||||
events: events,
|
||||
config: config,
|
||||
dbConfig: config.Get("database").(Config),
|
||||
connections: make(map[string]contracts.DBConnection),
|
||||
drivers: map[string]contracts.DBConnector{
|
||||
"mysql": drivers.MysqlConnector,
|
||||
"postgres": drivers.PostgresSqlConnector,
|
||||
"sqlite": drivers.SqliteConnector,
|
||||
"clickhouse": drivers.ClickHouseConnector,
|
||||
},
|
||||
}
|
||||
return NewFactory(config.Get("database").(Config), events)
|
||||
})
|
||||
application.Singleton("db", func(config contracts.Config, factory contracts.DBFactory) contracts.DBConnection {
|
||||
return factory.Connection()
|
||||
@@ -50,6 +41,7 @@ func (provider *ServiceProvider) Register(application contracts.Application) {
|
||||
}
|
||||
|
||||
func (provider *ServiceProvider) Start() error {
|
||||
table.SetFactory(provider.app.Get("db.factory").(contracts.DBFactory))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
53
table/factory.go
Normal file
53
table/factory.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package table
|
||||
|
||||
import (
|
||||
"github.com/goal-web/application"
|
||||
"github.com/goal-web/contracts"
|
||||
"github.com/goal-web/querybuilder"
|
||||
)
|
||||
|
||||
var factory contracts.DBFactory
|
||||
|
||||
func SetFactory(dbFactory contracts.DBFactory) {
|
||||
factory = dbFactory
|
||||
}
|
||||
|
||||
func getFactory() contracts.DBFactory {
|
||||
if factory == nil {
|
||||
factory = application.Get("db.factory").(contracts.DBFactory)
|
||||
}
|
||||
return factory
|
||||
}
|
||||
|
||||
func getTable(name string) *Table {
|
||||
builder := querybuilder.NewQuery(name)
|
||||
instance := &Table{
|
||||
QueryBuilder: builder,
|
||||
primaryKey: "id",
|
||||
table: name,
|
||||
}
|
||||
builder.Bind(instance)
|
||||
return instance
|
||||
}
|
||||
|
||||
// Query 将使用默认 connection
|
||||
func Query(name string) *Table {
|
||||
return getTable(name).SetConnection(factory.Connection())
|
||||
}
|
||||
|
||||
func FromModel(model contracts.Model) *Table {
|
||||
return WithConnection(model.GetTable(), model.GetConnection()).SetClass(model.GetClass()).SetPrimaryKey(model.GetPrimaryKey())
|
||||
}
|
||||
|
||||
// WithConnection 使用指定链接
|
||||
func WithConnection(name string, connection interface{}) *Table {
|
||||
if connection == "" || connection == nil {
|
||||
return Query(name)
|
||||
}
|
||||
return getTable(name).SetConnection(connection)
|
||||
}
|
||||
|
||||
// WithTX 使用TX
|
||||
func WithTX(name string, tx contracts.DBTx) contracts.QueryBuilder {
|
||||
return getTable(name).SetExecutor(tx)
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package table
|
||||
|
||||
import (
|
||||
"github.com/goal-web/application"
|
||||
"github.com/goal-web/contracts"
|
||||
"github.com/goal-web/querybuilder"
|
||||
"github.com/goal-web/supports/exceptions"
|
||||
"github.com/goal-web/supports/utils"
|
||||
)
|
||||
@@ -17,45 +15,12 @@ type Table struct {
|
||||
class contracts.Class
|
||||
}
|
||||
|
||||
func getTable(name string) *Table {
|
||||
builder := querybuilder.NewQuery(name)
|
||||
instance := &Table{
|
||||
QueryBuilder: builder,
|
||||
primaryKey: "id",
|
||||
table: name,
|
||||
}
|
||||
builder.Bind(instance)
|
||||
return instance
|
||||
}
|
||||
|
||||
// Query 将使用默认 connection
|
||||
func Query(name string) *Table {
|
||||
return getTable(name).SetConnection(application.Get("db").(contracts.DBConnection))
|
||||
}
|
||||
|
||||
func FromModel(model contracts.Model) *Table {
|
||||
return WithConnection(model.GetTable(), model.GetConnection()).SetClass(model.GetClass()).SetPrimaryKey(model.GetPrimaryKey())
|
||||
}
|
||||
|
||||
// WithConnection 使用指定链接
|
||||
func WithConnection(name string, connection interface{}) *Table {
|
||||
if connection == "" || connection == nil {
|
||||
return Query(name)
|
||||
}
|
||||
return getTable(name).SetConnection(connection)
|
||||
}
|
||||
|
||||
// WithTX 使用TX
|
||||
func WithTX(name string, tx contracts.DBTx) contracts.QueryBuilder {
|
||||
return getTable(name).SetExecutor(tx)
|
||||
}
|
||||
|
||||
// SetConnection 参数要么是 contracts.DBConnection 要么是 string
|
||||
func (table *Table) SetConnection(connection interface{}) *Table {
|
||||
if conn, ok := connection.(contracts.DBConnection); ok {
|
||||
table.executor = conn
|
||||
} else {
|
||||
table.executor = application.Get("db.factory").(contracts.DBFactory).Connection(utils.ConvertToString(connection, ""))
|
||||
table.executor = getFactory().Connection(utils.ConvertToString(connection, ""))
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
@@ -61,3 +61,34 @@ func TestMysqlDatabaseService(t *testing.T) {
|
||||
assert.True(t, table.Query("users").Count() == 0)
|
||||
|
||||
}
|
||||
|
||||
func TestMysqlDatabaseWithoutApplication(t *testing.T) {
|
||||
table.SetFactory(database.NewFactory(database.Config{
|
||||
Default: "mysql",
|
||||
Connections: map[string]contracts.Fields{
|
||||
"mysql": {
|
||||
"driver": "mysql",
|
||||
"host": "localhost",
|
||||
"port": "3306",
|
||||
"database": "goal",
|
||||
"username": "root",
|
||||
"password": "123456",
|
||||
"charset": "utf8mb4",
|
||||
"collation": "utf8mb4_unicode_ci",
|
||||
},
|
||||
},
|
||||
Migrations: "migrations",
|
||||
}, nil))
|
||||
|
||||
assert.True(t, table.Query("users").Count() == 0)
|
||||
|
||||
user := table.Query("users").Create(contracts.Fields{
|
||||
"name": "testing",
|
||||
})
|
||||
assert.NotNil(t, user)
|
||||
assert.True(t, user.(contracts.Fields)["name"] == "testing")
|
||||
assert.True(t, table.Query("users").Count() == 1)
|
||||
table.Query("users").Where("name", "testing").Delete()
|
||||
assert.True(t, table.Query("users").Count() == 0)
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user