feat: support config multiple databases (#123)

* feat: support config multiple databases
This commit is contained in:
Richard
2023-12-24 16:18:08 +08:00
committed by GitHub
parent 4ad255ebac
commit aa3b835575
9 changed files with 170 additions and 54 deletions

View File

@@ -2,6 +2,7 @@
## v1.8.2
- feat: support PostgreSQL
- feat: support config multiple databases
## v1.8.1
- fix: GitHub workflow badge URL

View File

@@ -1,10 +1,11 @@
Driver: mysql # 驱动名称,目前支持 mysqlpostgres默认: mysql
Name: eagle # 数据库名称
Addr: db:3306 # 如果是 docker,可以替换为 对应的服务名称eg: db:3306
UserName: root
Password: root
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数0意味着使用默认的大小2 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值设置后只打印慢查询日志默认为200ms
default:
Driver: mysql # 驱动名称,目前支持 mysqlpostgres默认: mysql
Name: eagle # 数据库名称
Addr: db:3306 # 如果是 docker,可以替换为 对应的服务名称eg: db:3306
UserName: root
Password: root
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数0意味着使用默认的大小2 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值设置后只打印慢查询日志默认为200ms

View File

@@ -1,10 +1,22 @@
Driver: mysql # 驱动名称,目前支持: mysqlpostgres默认: mysql
Name: eagle # 数据库名称
Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称eg: db:3306, pg:5432
UserName: root
Password: 123456
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数0意味着使用默认的大小2 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值设置后只打印慢查询日志默认为200ms
default:
Driver: mysql # 驱动名称,目前支持: mysqlpostgres默认: mysql
Name: eagle # 数据库名称
Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称eg: db:3306, pg:5432
UserName: root
Password: 123456
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数0意味着使用默认的大小2 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值设置后只打印慢查询日志默认为200ms
user:
Driver: mysql # 驱动名称,目前支持: mysqlpostgres默认: mysql
Name: eagle # 数据库名称
Addr: localhost:3306 # 如果是 docker,可以替换为 对应的服务名称eg: db:3306, pg:5432
UserName: root
Password: 123456
ShowLog: true # 是否打印所有SQL日志
MaxIdleConn: 10 # 最大闲置的连接数0意味着使用默认的大小2 小于0表示不使用连接池
MaxOpenConn: 60 # 最大打开的连接数, 需要小于数据库配置中的max_connections数
ConnMaxLifeTime: 4h # 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些
SlowThreshold: 500ms # 慢查询阈值设置后只打印慢查询日志默认为200ms

View File

@@ -2,7 +2,6 @@ package user
import (
"errors"
"time"
"github.com/gin-gonic/gin"
"github.com/spf13/cast"
@@ -44,7 +43,7 @@ func Get(c *gin.Context) {
return
}
time.Sleep(5 * time.Second)
//time.Sleep(5 * time.Second)
response.Success(c, u)
}

View File

@@ -5,35 +5,33 @@ import (
"gorm.io/gorm"
"github.com/go-eagle/eagle/pkg/config"
"github.com/go-eagle/eagle/pkg/storage/orm"
)
// DB 数据库全局变量
var DB *gorm.DB
const (
// DefaultDatabase default database
DefaultDatabase = "default"
// UserDatabase user database
UserDatabase = "user"
)
// Init 初始化数据库
func Init() *gorm.DB {
cfg, err := loadConf()
func Init() {
err := orm.New(
DefaultDatabase,
UserDatabase,
)
if err != nil {
panic(fmt.Sprintf("load orm conf err: %v", err))
panic(fmt.Sprintf("new orm database err: %v", err))
}
DB = orm.New(cfg)
return DB
}
// GetDB 返回默认的数据库
func GetDB() *gorm.DB {
return DB
func GetDB() (*gorm.DB, error) {
return orm.GetDB(DefaultDatabase)
}
// loadConf load database config
func loadConf() (ret *orm.Config, err error) {
var cfg orm.Config
if err := config.Load("database", &cfg); err != nil {
return nil, err
}
return &cfg, nil
// GetUserDB 获取用户数据库实例
func GetUserDB() (*gorm.DB, error) {
return orm.GetDB(UserDatabase)
}

View File

@@ -40,7 +40,7 @@ func (d *repository) UpdateUserFansStatus(ctx context.Context, db *gorm.DB, user
// GetFollowingUserList .
func (d *repository) GetFollowingUserList(ctx context.Context, userID, lastID uint64, limit int) ([]*model.UserFollowModel, error) {
userFollowList := make([]*model.UserFollowModel, 0)
db := model.GetDB()
db, _ := model.GetDB()
result := db.Where("user_id=? AND id<=? and status=1", userID, lastID).
Order("id desc").
Limit(limit).Find(&userFollowList)
@@ -56,7 +56,7 @@ func (d *repository) GetFollowingUserList(ctx context.Context, userID, lastID ui
// GetFollowerUserList get follower user list
func (d *repository) GetFollowerUserList(ctx context.Context, userID, lastID uint64, limit int) ([]*model.UserFansModel, error) {
userFollowerList := make([]*model.UserFansModel, 0)
db := model.GetDB()
db, _ := model.GetDB()
result := db.Where("user_id=? AND id<=? and status=1", userID, lastID).
Order("id desc").
Limit(limit).Find(&userFollowerList)
@@ -73,8 +73,8 @@ func (d *repository) GetFollowerUserList(ctx context.Context, userID, lastID uin
func (d *repository) GetFollowByUIds(ctx context.Context, userID uint64, followingUID []uint64) (map[uint64]*model.UserFollowModel, error) {
userFollowModel := make([]*model.UserFollowModel, 0)
retMap := make(map[uint64]*model.UserFollowModel)
err := model.GetDB().
db, _ := model.GetDB()
err := db.
Where("user_id=? AND followed_uid in (?) ", userID, followingUID).
Find(&userFollowModel).Error
@@ -93,12 +93,12 @@ func (d *repository) GetFollowByUIds(ctx context.Context, userID uint64, followi
func (d *repository) GetFansByUIds(ctx context.Context, userID uint64, followerUID []uint64) (map[uint64]*model.UserFansModel, error) {
userFansModel := make([]*model.UserFansModel, 0)
retMap := make(map[uint64]*model.UserFansModel)
err := model.GetDB().
db, _ := model.GetDB()
err := db.
Where("user_id=? AND follower_uid in (?) ", userID, followerUID).
Find(&userFansModel).Error
if err != nil && err != gorm.ErrRecordNotFound {
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return retMap, errors.Wrap(err, "[user_follow] get user fans err")
}

View File

@@ -39,7 +39,8 @@ func newRelations(svc *service) *relationService {
// IsFollowing 是否正在关注某用户
func (s *relationService) IsFollowing(ctx context.Context, userID uint64, followedUID uint64) bool {
userFollowModel := &model.UserFollowModel{}
result := model.GetDB().
db, _ := model.GetDB()
result := db.
Where("user_id=? AND followed_uid=? ", userID, followedUID).
Find(userFollowModel)
@@ -57,7 +58,7 @@ func (s *relationService) IsFollowing(ctx context.Context, userID uint64, follow
// Follow 关注目标用户
func (s *relationService) Follow(ctx context.Context, userID uint64, followedUID uint64) error {
db := model.GetDB()
db, _ := model.GetDB()
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
@@ -103,7 +104,7 @@ func (s *relationService) Follow(ctx context.Context, userID uint64, followedUID
// Unfollow 取消用户关注
func (s *relationService) Unfollow(ctx context.Context, userID uint64, followedUID uint64) error {
db := model.GetDB()
db, _ := model.GetDB()
tx := db.Begin()
defer func() {
if r := recover(); r != nil {

View File

@@ -76,7 +76,8 @@ func main() {
// redis.Init()
// init service
service.Svc = service.New(repository.New(model.GetDB()))
db, _ := model.GetDB()
service.Svc = service.New(repository.New(db))
gin.SetMode(cfg.Mode)

View File

@@ -5,8 +5,11 @@ import (
"fmt"
"log"
"os"
"sync"
"time"
"github.com/go-eagle/eagle/pkg/config"
otelgorm "github.com/1024casts/gorm-opentelemetry"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@@ -21,6 +24,16 @@ const (
DriverMySQL = "mysql"
// DriverPostgres postgresSQL driver
DriverPostgres = "postgres"
// DefaultDatabase default db name
DefaultDatabase = "default"
)
var (
// DBMap store database instance
DBMap = make(map[string]*gorm.DB)
// DBLock database locker
DBLock sync.Mutex
)
// Config database config
@@ -37,8 +50,82 @@ type Config struct {
SlowThreshold time.Duration // 慢查询时长默认500ms
}
// New connect to database and create a db instance
func New(c *Config) (db *gorm.DB) {
// New create a or multi database client
func New(names ...string) error {
if len(names) == 0 {
return fmt.Errorf("no set databasename")
}
clientManager := NewManager()
for _, name := range names {
_, err := clientManager.GetInstance(name)
if err != nil {
return fmt.Errorf("init database name: %+v, err: %+v", name, err)
}
}
return nil
}
// Manager define a manager
type Manager struct {
instances map[string]*gorm.DB
*sync.RWMutex
}
// NewManager create a database manager
func NewManager() *Manager {
return &Manager{
instances: make(map[string]*gorm.DB),
RWMutex: &sync.RWMutex{},
}
}
// GetDB get a database
func GetDB(name string) (*gorm.DB, error) {
DBLock.Lock()
defer DBLock.Unlock()
db, ok := DBMap[name]
if !ok {
db, err := NewManager().GetInstance(name)
if err != nil {
return nil, err
}
return db, nil
}
return db, nil
}
// GetInstance return a database client
func (m *Manager) GetInstance(name string) (*gorm.DB, error) {
// get client from map
m.RLock()
if ins, ok := m.instances[name]; ok {
m.RUnlock()
return ins, nil
}
m.RUnlock()
c, err := LoadConf(name)
if err != nil {
return nil, fmt.Errorf("load database conf err: %+v", err)
}
// create a database client
m.Lock()
defer m.Unlock()
instance := NewInstance(c)
m.instances[name] = instance
DBMap[name] = instance
return instance, nil
}
// NewInstance connect to database and create a db instance
func NewInstance(c *Config) (db *gorm.DB) {
var (
err error
sqlDB *sql.DB
@@ -83,6 +170,22 @@ func New(c *Config) (db *gorm.DB) {
return db
}
// LoadConf load database config
func LoadConf(name string) (ret *Config, err error) {
v, err := config.LoadWithType("database", "yaml")
if err != nil {
return nil, err
}
var c Config
err = v.UnmarshalKey(name, &c)
if err != nil {
return nil, err
}
return &c, nil
}
// getDSN return dsn string
func getDSN(c *Config) string {
// default mysql