diff --git a/neo4j/config.go b/neo4j/config.go index 8fa67550..31473ac5 100644 --- a/neo4j/config.go +++ b/neo4j/config.go @@ -1,4 +1,4 @@ -package neo4jstore +package neo4j import ( "time" @@ -48,12 +48,12 @@ type Config struct { // Optional. Default is "fiber_storage" Node string - // Reset clears any existing keys in existing Table + // Reset clears any existing keys (Nodes) // // Optional. Default is false Reset bool - // Time before deleting expired keys + // Time before deleting expired keys (Nodes) // // Optional. Default is 10 * time.Second GCInterval time.Duration diff --git a/neo4j/neo4j.go b/neo4j/neo4j.go index b6c026dc..9c0c47c3 100644 --- a/neo4j/neo4j.go +++ b/neo4j/neo4j.go @@ -1,13 +1,12 @@ -package neo4jstore +package neo4j import ( "context" "encoding/json" "fmt" - "os" + "log" "time" - "github.com/gofiber/utils/v2" "github.com/neo4j/neo4j-go-driver/v5/neo4j" "github.com/neo4j/neo4j-go-driver/v5/neo4j/auth" "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" @@ -28,7 +27,7 @@ type Storage struct { type model struct { Key string `json:"k"` - Val string `json:"v"` + Val []byte `json:"v"` Exp int64 `json:"e"` } @@ -57,28 +56,28 @@ func New(config ...Config) *Storage { Configurations: cfg.Configurations, }) if err != nil { - fmt.Fprintf(os.Stderr, "Unable to create connection pool: %v\n", err) + log.Panicf("Unable to create connection pool: %v\n", err) } } ctx := context.Background() if err := db.VerifyConnectivity(ctx); err != nil { - panic(err) + log.Panicf("Unable to verify connection: %v\n", err) } - // truncate node if reset set to true + // delete all nodes if reset set to true if cfg.Reset { if _, err := neo4j.ExecuteQuery(ctx, db, fmt.Sprintf("MATCH (n:%s) DELETE n FINISH", cfg.Node), nil, neo4j.EagerResultTransformer); err != nil { db.Close(ctx) - panic(err) + log.Panicf("Unable to reset storage: %v\n", err) } } // create index on key if _, err := neo4j.ExecuteQuery(ctx, db, fmt.Sprintf("CREATE INDEX neo4jstore_key_idx IF NOT EXISTS FOR (n:%s) ON (n.k)", cfg.Node), nil, neo4j.EagerResultTransformer); err != nil { db.Close(ctx) - panic(err) + log.Panicf("Unable to create index on key: %v\n", err) } store := &Storage{ @@ -122,14 +121,18 @@ func (s *Storage) Get(key string) ([]byte, error) { // result model var model model - mapToStruct(data, &model) + bt, _ := json.Marshal(data) + + if err := json.Unmarshal(bt, &model); err != nil { + return nil, fmt.Errorf("error parsing result data: %v", err) + } // If the expiration time has already passed, then return nil if model.Exp != 0 && model.Exp <= time.Now().Unix() { return nil, nil } - return utils.UnsafeBytes(model.Val), nil + return model.Val, nil } // Set key with value @@ -142,12 +145,10 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { expireAt = time.Now().Add(exp).Unix() } - valStr := utils.UnsafeString(val) - // create the structure for the storage data := model{ Key: key, - Val: valStr, + Val: val, Exp: expireAt, } @@ -213,11 +214,3 @@ func (s *Storage) gcTicker() { func (s *Storage) gc(t time.Time) { _, _ = neo4j.ExecuteQuery(context.Background(), s.db, s.cypherGC, map[string]any{"exp": t.Unix()}, neo4j.EagerResultTransformer) } - -func mapToStruct(src map[string]any, dest any) { - bt, _ := json.Marshal(src) - - if err := json.Unmarshal(bt, dest); err != nil { - panic(err) - } -} diff --git a/neo4j/neo4j_test.go b/neo4j/neo4j_test.go index edc63a84..71e3be95 100644 --- a/neo4j/neo4j_test.go +++ b/neo4j/neo4j_test.go @@ -1,9 +1,8 @@ -package neo4jstore +package neo4j import ( "context" "log" - "os" "testing" "time" @@ -11,15 +10,14 @@ import ( "github.com/testcontainers/testcontainers-go/modules/neo4j" ) -var testStore *Storage +const neo4jImgVer string = "neo4j:5.26" -// TestMain sets up and tears down the test container -func TestMain(m *testing.M) { +func startContainer() (*Storage, func()) { ctx := context.Background() // Start a Neo4j test container neo4jContainer, err := neo4j.Run(ctx, - "neo4j:5.26", + neo4jImgVer, neo4j.WithAdminPassword("pass#w*#d"), ) if err != nil { @@ -40,21 +38,23 @@ func TestMain(m *testing.M) { Password: "pass#w*#d", }) - testStore = store + return store, func() { + store.Close() - defer testStore.Close() - defer func() { if err := neo4jContainer.Terminate(ctx); err != nil { log.Printf("Failed to terminate Neo4j container: %v", err) } - }() + } - code := m.Run() - - os.Exit(code) } func Test_Neo4jStore_Set(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + var ( key = "john" val = []byte("doe") @@ -65,6 +65,12 @@ func Test_Neo4jStore_Set(t *testing.T) { } func Test_Neo4jStore_Upsert(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + var ( key = "john" val = []byte("doe") @@ -78,6 +84,12 @@ func Test_Neo4jStore_Upsert(t *testing.T) { } func Test_Neo4jStore_Get(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + var ( key = "john" val = []byte("doe") @@ -92,6 +104,12 @@ func Test_Neo4jStore_Get(t *testing.T) { } func Test_Neo4jStore_Set_Expiration(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + var ( key = "john" val = []byte("doe") @@ -109,6 +127,12 @@ func Test_Neo4jStore_Set_Expiration(t *testing.T) { } func Test_Neo4jStore_Get_Expired(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + key := "john" result, err := testStore.Get(key) @@ -117,12 +141,24 @@ func Test_Neo4jStore_Get_Expired(t *testing.T) { } func Test_Neo4jStore_Get_NotExist(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + result, err := testStore.Get("notexist") require.NoError(t, err) require.Zero(t, len(result)) } func Test_Neo4jStore_Delete(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + var ( key = "john" val = []byte("doe") @@ -140,6 +176,12 @@ func Test_Neo4jStore_Delete(t *testing.T) { } func Test_Neo4jStore_Reset(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + val := []byte("doe") err := testStore.Set("john1", val, 0) @@ -161,6 +203,12 @@ func Test_Neo4jStore_Reset(t *testing.T) { } func Test_Neo4jStore_Non_UTF8(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + val := []byte("0xF5") err := testStore.Set("0xF6", val, 0) @@ -172,14 +220,29 @@ func Test_Neo4jStore_Non_UTF8(t *testing.T) { } func Test_Neo4jStore_Close(t *testing.T) { + t.Parallel() + + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + require.Nil(t, testStore.Close()) } func Test_Neo4jStore_Conn(t *testing.T) { + t.Parallel() + testStore, cleanup := startContainer() + + t.Cleanup(cleanup) + require.True(t, testStore.Conn() != nil) } func Benchmark_Neo4jStore_Set(b *testing.B) { + testStore, cleanup := startContainer() + + b.Cleanup(cleanup) + b.ReportAllocs() b.ResetTimer() @@ -192,6 +255,10 @@ func Benchmark_Neo4jStore_Set(b *testing.B) { } func Benchmark_Neo4jStore_Get(b *testing.B) { + testStore, cleanup := startContainer() + + b.Cleanup(cleanup) + err := testStore.Set("john", []byte("doe"), 0) require.NoError(b, err) @@ -206,6 +273,10 @@ func Benchmark_Neo4jStore_Get(b *testing.B) { } func Benchmark_Neo4jStore_SetAndDelete(b *testing.B) { + testStore, cleanup := startContainer() + + b.Cleanup(cleanup) + b.ReportAllocs() b.ResetTimer()