// Source: storage/neo4j/neo4j.go
// Last modified: 2025-04-08 17:53:34 +02:00
// 221 lines, 5.0 KiB, Go

package neo4j
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
)
// Storage is a key/value store that keeps its entries as Neo4j nodes.
// It holds the driver, the settings for the background expiry sweep, and the
// Cypher statements executed by its methods.
type Storage struct {
	// db is the driver every query is executed through.
	db neo4j.DriverWithContext
	// gcInterval is the period of the ticker that triggers gc.
	gcInterval time.Duration
	// done carries the stop signal from Close to gcTicker.
	done chan struct{}

	// Cypher statements run by Get, Set, Delete, Reset and gc respectively.
	cypherMatch, cypherMerge, cypherDelete, cypherReset, cypherGC string
}
// model is the shape a stored node's properties are decoded into by Get.
// The JSON tags match the node property names used in the Cypher statements.
type model struct {
// Key is the entry's key (property "k").
Key string `json:"k"`
// Val is the stored value (property "v").
Val []byte `json:"v"`
// Exp is the expiry as Unix seconds; 0 means the entry never expires.
Exp int64 `json:"e"`
}
// neo4jConnConfig bundles the arguments that newDriverWithContext forwards to
// neo4j.NewDriverWithContext.
type neo4jConnConfig struct {
// URI is passed as the driver's target address.
URI string
// Auth is the token manager passed as the driver's credentials.
Auth auth.TokenManager
// Configurations are optional callbacks that adjust the driver's config.Config.
Configurations []func(*config.Config)
}
// newDriverWithContext opens a Neo4j driver from the given connection
// settings and returns it together with any construction error.
func newDriverWithContext(cfg neo4jConnConfig) (neo4j.DriverWithContext, error) {
	target, token, options := cfg.URI, cfg.Auth, cfg.Configurations
	driver, err := neo4j.NewDriverWithContext(target, token, options...)
	return driver, err
}
// New creates a Neo4j-backed Storage and starts its background expiry sweep.
//
// The optional config is resolved through configDefault. If the resolved
// config carries no driver (cfg.DB == nil), one is opened from cfg.URI,
// cfg.Auth and cfg.Configurations. New then verifies connectivity, deletes
// all nodes labelled cfg.Node when cfg.Reset is set, and makes sure an index
// on the key property exists.
//
// New panics (via log.Panicf) when any of those steps fails. Once a driver is
// in hand it is closed before panicking, so a failed setup does not leave the
// driver open.
func New(config ...Config) *Storage {
	cfg := configDefault(config...)

	// Use the caller's driver when provided, otherwise open one.
	db := cfg.DB
	if db == nil {
		var err error
		db, err = newDriverWithContext(neo4jConnConfig{
			URI:            cfg.URI,
			Auth:           cfg.Auth,
			Configurations: cfg.Configurations,
		})
		if err != nil {
			log.Panicf("Unable to create connection pool: %v\n", err)
		}
	}

	ctx := context.Background()

	// closeAndPanic releases the driver and then aborts with "<what>: <cause>".
	closeAndPanic := func(what string, cause error) {
		if err := db.Close(ctx); err != nil {
			log.Printf("Error closing storage: %v\n", err)
		}
		log.Panicf("%s: %v\n", what, cause)
	}

	if err := db.VerifyConnectivity(ctx); err != nil {
		closeAndPanic("Unable to verify connection", err)
	}

	// The same statement serves the optional startup reset and Storage.Reset.
	resetQuery := fmt.Sprintf("MATCH (n:%s) DELETE n FINISH", cfg.Node)

	if cfg.Reset {
		if _, err := neo4j.ExecuteQuery(ctx, db, resetQuery, nil, neo4j.EagerResultTransformer); err != nil {
			closeAndPanic("Unable to reset storage", err)
		}
	}

	// NOTE(review): the index name is fixed while the label comes from
	// cfg.Node; confirm how IF NOT EXISTS behaves when a store with a
	// different label has already created an index under this name.
	indexQuery := fmt.Sprintf("CREATE INDEX neo4jstore_key_idx IF NOT EXISTS FOR (n:%s) ON (n.k)", cfg.Node)
	if _, err := neo4j.ExecuteQuery(ctx, db, indexQuery, nil, neo4j.EagerResultTransformer); err != nil {
		closeAndPanic("Unable to create index on key", err)
	}

	store := &Storage{
		db:           db,
		gcInterval:   cfg.GCInterval,
		done:         make(chan struct{}),
		cypherMatch:  fmt.Sprintf("OPTIONAL MATCH (n:%s{ k: $key }) RETURN n { .* } AS data", cfg.Node),
		cypherMerge:  fmt.Sprintf("MERGE (n:%s{ k: $key }) SET n.v = $val, n.e = $exp FINISH", cfg.Node),
		cypherDelete: fmt.Sprintf("MATCH (n:%s{ k: $key }) DELETE n FINISH", cfg.Node),
		cypherReset:  resetQuery,
		// Cypher spells inequality "<>"; an expiry of 0 marks entries that
		// never expire and must be skipped by the sweep.
		cypherGC: fmt.Sprintf("MATCH (n:%s) WHERE n.e <= $exp AND n.e <> 0 DELETE n FINISH", cfg.Node),
	}

	go store.gcTicker()

	return store
}
// Get returns the value stored under key.
//
// It returns (nil, nil) when key is empty, when no node carries that key, or
// when the entry has a non-zero expiry that is not in the future. A non-nil
// error means the query failed or the returned node properties could not be
// read or decoded.
func (s *Storage) Get(key string) ([]byte, error) {
	if key == "" {
		return nil, nil
	}

	res, err := neo4j.ExecuteQuery(
		context.Background(), s.db, s.cypherMatch,
		map[string]any{"key": key},
		neo4j.EagerResultTransformer,
	)
	if err != nil {
		return nil, err
	}
	// Defensive: never index into an empty result set.
	if len(res.Records) == 0 {
		return nil, nil
	}

	data, isNil, err := neo4j.GetRecordValue[map[string]any](res.Records[0], "data")
	if err != nil {
		return nil, fmt.Errorf("error reading result data: %w", err)
	}
	// A null "data" column means no node matched the key.
	if isNil || data == nil {
		return nil, nil
	}

	// Round-trip through JSON to map the property names (k, v, e) onto model.
	raw, err := json.Marshal(data)
	if err != nil {
		return nil, fmt.Errorf("error parsing result data: %w", err)
	}
	var entry model
	if err := json.Unmarshal(raw, &entry); err != nil {
		return nil, fmt.Errorf("error parsing result data: %w", err)
	}

	// Expired entries are reported as missing; an expiry of 0 never expires.
	if entry.Exp != 0 && entry.Exp <= time.Now().Unix() {
		return nil, nil
	}

	return entry.Val, nil
}
// Set stores val under key, creating the node or overwriting its value and
// expiry. An empty key or empty val is ignored and nil is returned.
// A zero exp stores an expiry of 0, which Get treats as "never expires";
// otherwise the expiry is now+exp in Unix seconds.
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
	if key == "" || len(val) == 0 {
		return nil
	}

	var deadline int64
	if exp != 0 {
		deadline = time.Now().Add(exp).Unix()
	}

	params := map[string]any{
		"key": key,
		"val": val,
		"exp": deadline,
	}

	_, err := neo4j.ExecuteQuery(context.Background(), s.db, s.cypherMerge, params, neo4j.EagerResultTransformer)
	return err
}
// Delete removes the node stored under key. An empty key is ignored and nil
// is returned; otherwise the result is the error of the delete query, if any.
func (s *Storage) Delete(key string) error {
	if key == "" {
		return nil
	}

	params := map[string]any{"key": key}
	if _, err := neo4j.ExecuteQuery(
		context.Background(), s.db, s.cypherDelete,
		params,
		neo4j.EagerResultTransformer,
	); err != nil {
		return err
	}
	return nil
}
// Reset removes every entry by running the store's reset statement and
// returns the query error, if any.
func (s *Storage) Reset() error {
	ctx := context.Background()
	if _, err := neo4j.ExecuteQuery(ctx, s.db, s.cypherReset, nil, neo4j.EagerResultTransformer); err != nil {
		return err
	}
	return nil
}
// Close stops the background expiry sweep and then closes the driver,
// returning the driver's close error.
//
// NOTE(review): done is unbuffered, so the send blocks until gcTicker
// receives it. gcTicker returns after that receive, so a second call to
// Close would block forever.
func (s *Storage) Close() error {
	var stop struct{}
	s.done <- stop

	ctx := context.Background()
	return s.db.Close(ctx)
}
// Conn exposes the underlying Neo4j driver used by the store.
func (s *Storage) Conn() neo4j.DriverWithContext {
	driver := s.db
	return driver
}
// gcTicker runs gc once per gcInterval until a value arrives on done.
// The ticker is stopped when the loop exits.
func (s *Storage) gcTicker() {
	tick := time.NewTicker(s.gcInterval)
	defer tick.Stop()

	for {
		select {
		case now := <-tick.C:
			// Sweep using the tick's own timestamp as the expiry cut-off.
			s.gc(now)
		case <-s.done:
			return
		}
	}
}
// gc runs the store's sweep statement with t (as Unix seconds) as the expiry
// cut-off. The sweep is best-effort: a failure is logged rather than
// returned, and the next tick simply tries again.
func (s *Storage) gc(t time.Time) {
	_, err := neo4j.ExecuteQuery(
		context.Background(), s.db, s.cypherGC,
		map[string]any{"exp": t.Unix()},
		neo4j.EagerResultTransformer,
	)
	if err != nil {
		// Surface the failure so a permanently failing sweep is not invisible.
		log.Printf("Error deleting expired entries: %v\n", err)
	}
}