From a7fd5bdbe5af8c60a0f66186a28408e3f9af538e Mon Sep 17 00:00:00 2001 From: Wei Lun Date: Sat, 31 Oct 2020 17:39:12 +0800 Subject: [PATCH] update mongo storage --- mongodb/README.md | 4 ++++ mongodb/config.go | 10 ++++++---- mongodb/mongodb.go | 35 +++++++++++++++++++++++++---------- mongodb/mongodb_test.go | 40 ++++++++++++++-------------------------- 4 files changed, 49 insertions(+), 40 deletions(-) diff --git a/mongodb/README.md b/mongodb/README.md index 5afc4139..ff8cd049 100644 --- a/mongodb/README.md +++ b/mongodb/README.md @@ -18,6 +18,10 @@ func main() { Database: "_database", Collection: "_storage", }) + + // Access DB connection + // for disconnet for example + store.DB.Client().Disconnect(context.TODO()) } ``` diff --git a/mongodb/config.go b/mongodb/config.go index 67359d43..3896281f 100644 --- a/mongodb/config.go +++ b/mongodb/config.go @@ -15,7 +15,9 @@ import ( // Config defines the config for storage. type Config struct { // Custom options - Addr string + + //https://docs.mongodb.com/manual/reference/connection-string/ + URI string Database string Collection string @@ -52,15 +54,15 @@ type Config struct { // ConfigDefault is the default config var ConfigDefault = Config{ - Addr: "127.0.0.1:27017", + URI: "mongodb://127.0.0.1:27017", Database: "_database", Collection: "_storage", } // Helper function to set default values func configDefault(cfg Config) Config { - if cfg.Addr == "" { - cfg.Addr = ConfigDefault.Addr + if cfg.URI == "" { + cfg.URI = ConfigDefault.URI } if cfg.Database == "" { cfg.Database = ConfigDefault.Database diff --git a/mongodb/mongodb.go b/mongodb/mongodb.go index 74bc9698..55dfebf1 100644 --- a/mongodb/mongodb.go +++ b/mongodb/mongodb.go @@ -12,7 +12,8 @@ import ( // Storage interface that is implemented by storage providers type Storage struct { - db *mongo.Collection + DB *mongo.Database + col *mongo.Collection } type MongoStorage struct { @@ -42,7 +43,6 @@ func New(config ...Config) *Storage { opt.SetDialer(cfg.Dialer) opt.SetDirect(cfg.Direct) opt.SetDisableOCSPEndpointCheck(cfg.DisableOCSPEndpointCheck) - opt.SetHeartbeatInterval(cfg.HeartbeatInterval) opt.SetHosts(cfg.Hosts) opt.SetLocalThreshold(cfg.LocalThreshold) opt.SetMaxConnIdleTime(cfg.MaxConnIdleTime) @@ -63,14 +63,28 @@ func New(config ...Config) *Storage { opt.SetZlibLevel(cfg.ZlibLevel) opt.SetZstdLevel(cfg.ZstdLevel) + // default time.Duration is not nil + // will cause panic: non-positive interval for NewTicker + if cfg.HeartbeatInterval > 0 { + opt.SetHeartbeatInterval(cfg.HeartbeatInterval) + } + // Create mongo client - client, err := mongo.NewClient(opt.ApplyURI("mongodb://" + cfg.Addr)) + client, err := mongo.NewClient(opt.ApplyURI(cfg.URI)) + if err != nil { + panic(err) + } + + ctx, cancel := context.WithTimeout(context.TODO(), 20*time.Second) + defer cancel() + err = client.Connect(ctx) if err != nil { panic(err) } // Get collection from database - db := client.Database(cfg.Database).Collection(cfg.Collection) + db := client.Database(cfg.Database) + col := db.Collection(cfg.Collection) // expired data may exist for some time beyond the 60 second period between runs of the background task. // more on https://docs.mongodb.com/manual/core/index-ttl/ @@ -82,18 +96,19 @@ func New(config ...Config) *Storage { Options: options.Index().SetExpireAfterSeconds(0), } - if _, err := db.Indexes().CreateOne(context.TODO(), indexModel); err != nil { + if _, err := col.Indexes().CreateOne(context.TODO(), indexModel); err != nil { panic(err) } return &Storage{ - db: db, + DB: db, + col: col, } } // Get value by key func (s *Storage) Get(key string) ([]byte, error) { - res := s.db.FindOne(context.TODO(), bson.M{"key": key}) + res := s.col.FindOne(context.TODO(), bson.M{"key": key}) result := MongoStorage{} if err := res.Err(); err != nil { @@ -117,17 +132,17 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { if exp != 0 { replace.Exp = time.Now().Add(exp).UTC() } - _, err := s.db.ReplaceOne(context.TODO(), filter, replace, options.Replace().SetUpsert(true)) + _, err := s.col.ReplaceOne(context.TODO(), filter, replace, options.Replace().SetUpsert(true)) return err } // Delete document by key func (s *Storage) Delete(key string) error { - _, err := s.db.DeleteOne(context.TODO(), bson.M{"key": key}) + _, err := s.col.DeleteOne(context.TODO(), bson.M{"key": key}) return err } // Clear all keys by drop collection func (s *Storage) Clear() error { - return s.db.Drop(context.TODO()) + return s.col.Drop(context.TODO()) } diff --git a/mongodb/mongodb_test.go b/mongodb/mongodb_test.go index 34650c2d..3fa02ad7 100644 --- a/mongodb/mongodb_test.go +++ b/mongodb/mongodb_test.go @@ -4,11 +4,8 @@ import ( "context" "github.com/gofiber/utils" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" "os" "testing" - "time" ) const ( @@ -18,19 +15,13 @@ const ( var uri = os.Getenv("MONGO_URI") -func Connect() (*mongo.Database, *mongo.Collection) { - client, err := mongo.NewClient(options.Client().ApplyURI(uri)) - if err != nil { - panic(err) +func getConfig() Config { + + return Config{ + URI: uri, + Database: dbName, + Collection: colName, } - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - err = client.Connect(ctx) - - db := client.Database(dbName) - - return db, db.Collection(colName) } func contains(arr []string, item string) bool { @@ -43,10 +34,9 @@ func contains(arr []string, item string) bool { } func TestMongoStore_Set_Get(t *testing.T) { - db, col := Connect() - store := New(col) + store := New(getConfig()) defer func() { - _ = db.Client().Disconnect(context.TODO()) + _ = store.DB.Client().Disconnect(context.TODO()) }() key := "example_key" @@ -60,10 +50,9 @@ func TestMongoStore_Set_Get(t *testing.T) { } func TestMongoStore_Delete(t *testing.T) { - db, col := Connect() - store := New(col) + store := New(getConfig()) defer func() { - _ = db.Client().Disconnect(context.TODO()) + _ = store.DB.Client().Disconnect(context.TODO()) }() key := "example_key_2" @@ -78,21 +67,20 @@ func TestMongoStore_Delete(t *testing.T) { } func TestMongoStore_Clear(t *testing.T) { - db, col := Connect() - store := New(col) + store := New(getConfig()) defer func() { - _ = db.Client().Disconnect(context.TODO()) + _ = store.DB.Client().Disconnect(context.TODO()) }() key := "example_key_2" value := []byte("123") _ = store.Set(key, value, 10) - names, _ := db.ListCollectionNames(context.TODO(), bson.D{}) + names, _ := store.DB.ListCollectionNames(context.TODO(), bson.D{}) utils.AssertEqual(t, true, contains(names, colName), "has collection") _ = store.Clear() - names2, _ := db.ListCollectionNames(context.TODO(), bson.D{}) + names2, _ := store.DB.ListCollectionNames(context.TODO(), bson.D{}) utils.AssertEqual(t, false, contains(names2, colName), "do not have collection") }