Files
storage/nats/nats.go
2024-07-07 21:00:09 -04:00

325 lines
6.2 KiB
Go

package nats
import (
"bytes"
"context"
"encoding/gob"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
)
// Storage is a key-value store backed by a NATS JetStream KV bucket.
// The KV handle and the collected errors are written by the NATS connection
// handlers registered in New (and by Reset), so kv and err are guarded by mu.
type Storage struct {
// nc is the underlying NATS connection created in New.
nc *nats.Conn
// kv is the JetStream KV handle for cfg.KeyValueConfig.Bucket; nil while no
// handle is available, in which case Get/Set/Delete/Keys return an error.
kv jetstream.KeyValue
// err accumulates errors reported by the connection handlers and Reset; it is
// included in "kv not initialized" errors.
err error
// ctx is passed to every JetStream call.
ctx context.Context
// cfg is the configuration the storage was created with.
cfg Config
// mu guards kv and err (Close and Conn also hold it while using nc).
mu sync.RWMutex
}
// entry is the envelope that is gob-encoded and stored as each KV value.
// Its field names are part of the gob encoding, so renaming them breaks
// decoding of values that are already stored.
type entry struct {
// Data is the caller's value.
Data []byte
// Expiry is an absolute Unix timestamp in seconds, or 0 when Set was called
// with exp == 0 (no expiration requested).
Expiry int64
}
// init registers entry with gob.
// NOTE(review): entry is always encoded/decoded as a concrete type here, so
// the registration looks unnecessary — confirm before removing.
func init() {
gob.Register(entry{})
}
// connectHandler is registered via nats.ConnectHandler and reused by
// reconnectHandler. It acquires the JetStream KV handle for the configured
// bucket, creating the bucket if it does not exist yet.
//
// The JetStream round-trips run before s.mu is taken, so Get/Set/Delete/Keys
// are not blocked on network I/O. On failure the error is accumulated in s.err
// and any previously acquired handle is kept. A failed lookup after a
// reconnect therefore no longer replaces a working handle with nil.
func (s *Storage) connectHandler(nc *nats.Conn) {
	kv, err := newNatsKV(nc, s.ctx, s.cfg.KeyValueConfig)

	s.mu.Lock()
	defer s.mu.Unlock()
	if err != nil {
		s.err = errors.Join(s.err, err)
		return
	}
	s.kv = kv
}
// disconnectErrHandler is registered via nats.DisconnectErrHandler. It sets
// RetryOnFailedConnect on the connection's options and records a non-nil
// disconnect error on s.err.
//
// NOTE(review): nc.Opts is written here without any synchronization with the
// nats client, and RetryOnFailedConnect presumably only matters for the initial
// connect — confirm this assignment is needed.
func (s *Storage) disconnectErrHandler(nc *nats.Conn, err error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	nc.Opts.RetryOnFailedConnect = true
	if err == nil {
		return
	}
	s.err = errors.Join(s.err, err)
}

// reconnectHandler is registered via nats.ReconnectHandler. A reconnect is
// handled the same way as the initial connect: the connect-time KV setup is
// re-run for the re-established connection.
func (s *Storage) reconnectHandler(nc *nats.Conn) {
	s.connectHandler(nc)
}
// errorHandler is registered via nats.ErrorHandler and records asynchronous
// client errors on s.err.
//
// nats may invoke this callback with a nil subscription for errors that are
// not tied to a subscription. The subject is only added to the message when
// sub is non-nil, which avoids a nil-pointer panic in the callback goroutine.
func (s *Storage) errorHandler(nc *nats.Conn, sub *nats.Subscription, err error) {
	if err == nil {
		return
	}
	if sub != nil {
		err = fmt.Errorf("subject %q: %w", sub.Subject, err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.err = errors.Join(s.err, err)
}
// newNatsKV opens the JetStream KV bucket named by keyValueConfig.Bucket on nc.
// If the bucket does not exist, it is created from keyValueConfig.
//
// ctx is the second parameter (contrary to Go convention) to keep existing
// call sites unchanged.
func newNatsKV(nc *nats.Conn, ctx context.Context, keyValueConfig jetstream.KeyValueConfig) (jetstream.KeyValue, error) {
	js, err := jetstream.New(nc)
	if err != nil {
		return nil, fmt.Errorf("get jetstream: %w", err)
	}

	bucket, err := js.KeyValue(ctx, keyValueConfig.Bucket)
	switch {
	case err == nil:
		return bucket, nil
	case !errors.Is(err, jetstream.ErrBucketNotFound):
		return nil, fmt.Errorf("jetstream: get kv: %w", err)
	}

	// The bucket is missing: create it with the supplied configuration.
	bucket, err = js.CreateKeyValue(ctx, keyValueConfig)
	if err != nil {
		return nil, fmt.Errorf("jetstream: create kv: %w", err)
	}
	return bucket, nil
}
// processUrlString splits a comma-separated server list into individual URLs.
// Each element is trimmed of surrounding whitespace, and empty elements are
// dropped. The result is filtered in place over the split slice, so it is
// never nil; an input with no URLs yields an empty, non-nil slice.
func processUrlString(url string) []string {
	parts := strings.Split(url, ",")
	kept := parts[:0]
	for _, part := range parts {
		if trimmed := strings.TrimSpace(part); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
// New creates a Storage backed by a NATS JetStream key-value bucket.
//
// The storage's connect/disconnect/reconnect/error handlers are placed ahead of
// cfg.NatsOptions. User-supplied options are applied after them and can
// therefore override them.
//
// New panics if an option fails to apply or if connecting fails. The one
// exception is a dial error when RetryOnFailedConnect is enabled, because the
// client keeps retrying in that case.
//
// Before returning, New waits up to cfg.WaitForConnection for the connect
// handler to install the KV handle. It returns as soon as the handle is ready
// instead of always sleeping for the full duration.
func New(config ...Config) *Storage {
	cfg := configDefault(config...)
	storage := &Storage{
		ctx: cfg.Context,
		cfg: cfg,
	}

	// Install our handlers first; user options that follow may override them.
	cfg.NatsOptions = append(
		[]nats.Option{
			nats.ConnectHandler(storage.connectHandler),
			nats.DisconnectErrHandler(storage.disconnectErrHandler),
			nats.ReconnectHandler(storage.reconnectHandler),
			nats.ErrorHandler(storage.errorHandler),
		},
		cfg.NatsOptions...,
	)
	natsOpts := nats.GetDefaultOptions()
	natsOpts.Servers = processUrlString(cfg.URLs)
	for _, opt := range cfg.NatsOptions {
		if opt == nil {
			continue
		}
		if err := opt(&natsOpts); err != nil {
			panic(err)
		}
	}

	var err error
	storage.nc, err = natsOpts.Connect()
	if err != nil {
		// A dial failure is tolerated when the client is configured to keep
		// retrying. errors.As also matches a wrapped *net.OpError.
		var opErr *net.OpError
		tolerated := natsOpts.RetryOnFailedConnect && errors.As(err, &opErr) && opErr.Op == "dial"
		if !tolerated {
			panic(err)
		}
	}

	storage.waitForKV(cfg.WaitForConnection)
	return storage
}

// waitForKV blocks until a KV handle has been stored in s.kv or until timeout
// elapses, whichever happens first. A non-positive timeout returns immediately.
func (s *Storage) waitForKV(timeout time.Duration) {
	const pollInterval = 10 * time.Millisecond
	deadline := time.Now().Add(timeout)
	for {
		s.mu.RLock()
		ready := s.kv != nil
		s.mu.RUnlock()

		remaining := time.Until(deadline)
		if ready || remaining <= 0 {
			return
		}
		if remaining > pollInterval {
			remaining = pollInterval
		}
		time.Sleep(remaining)
	}
}
// Get returns the value stored under key.
//
// It returns (nil, nil) in four cases: key is empty, the key does not exist,
// the stored entry has expired, or the entry cannot be decoded. Expired and
// undecodable entries are deleted on a best-effort basis.
//
// An Expiry of 0 is what Set stores for exp == 0 and means "never expires".
// Previously such entries were treated as expired and deleted on first read.
func (s *Storage) Get(key string) ([]byte, error) {
	if key == "" {
		return nil, nil
	}

	// Snapshot both fields under the lock: reading s.err after RUnlock would
	// race with the connection handlers.
	s.mu.RLock()
	kv, initErr := s.kv, s.err
	s.mu.RUnlock()
	if kv == nil {
		return nil, fmt.Errorf("kv not initialized: %v", initErr)
	}

	v, err := kv.Get(s.ctx, key)
	if err != nil {
		if errors.Is(err, jetstream.ErrKeyNotFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("get: %w", err)
	}

	var e entry
	decodeErr := gob.NewDecoder(bytes.NewReader(v.Value())).Decode(&e)
	expired := e.Expiry != 0 && e.Expiry <= time.Now().Unix()
	if decodeErr != nil || expired {
		// Best-effort cleanup: a failed delete only leaves a stale entry
		// behind and must not turn a cache miss into an error.
		_ = kv.Delete(s.ctx, key)
		return nil, nil
	}
	return e.Data, nil
}
// Set stores val under key. exp is the time-to-live; 0 means the entry never
// expires. An empty key or empty value is ignored, and nil is returned.
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
	if key == "" || len(val) == 0 {
		return nil
	}

	// Snapshot both fields under the lock: reading s.err after RUnlock would
	// race with the connection handlers.
	s.mu.RLock()
	kv, initErr := s.kv, s.err
	s.mu.RUnlock()
	if kv == nil {
		return fmt.Errorf("kv not initialized: %v", initErr)
	}

	// Expiry is stored as an absolute Unix timestamp in seconds; 0 marks a
	// non-expiring entry.
	var expiry int64
	if exp != 0 {
		expiry = time.Now().Add(exp).Unix()
	}

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(entry{Data: val, Expiry: expiry}); err != nil {
		return fmt.Errorf("encode: %w", err)
	}

	_, err := kv.Put(s.ctx, key, buf.Bytes())
	if errors.Is(err, jetstream.ErrKeyNotFound) {
		// Fall back to Create. Assigning to the outer err (instead of shadowing
		// it) makes a successful Create report success, not the Put error.
		if _, err = kv.Create(s.ctx, key, buf.Bytes()); err != nil {
			return fmt.Errorf("create: %w", err)
		}
	}
	return err
}
// Delete removes key from the bucket. An empty key is ignored, and nil is
// returned.
func (s *Storage) Delete(key string) error {
	if key == "" {
		return nil
	}

	// Snapshot both fields under the lock: reading s.err after RUnlock would
	// race with the connection handlers.
	s.mu.RLock()
	kv, initErr := s.kv, s.err
	s.mu.RUnlock()
	if kv == nil {
		return fmt.Errorf("kv not initialized: %v", initErr)
	}
	return kv.Delete(s.ctx, key)
}
// Reset removes every key by deleting the configured bucket and creating it
// again from cfg.KeyValueConfig. A bucket that does not exist is not treated
// as an error, because the goal of an empty store is already met.
//
// On success the new handle is installed and s.err is cleared. On failure
// s.kv is cleared, s.err is set to the creation error, and that error is
// returned.
func (s *Storage) Reset() error {
	js, err := jetstream.New(s.nc)
	if err != nil {
		return fmt.Errorf("get jetstream: %w", err)
	}

	err = js.DeleteKeyValue(s.ctx, s.cfg.KeyValueConfig.Bucket)
	if err != nil && !errors.Is(err, jetstream.ErrBucketNotFound) {
		return fmt.Errorf("delete kv: %w", err)
	}

	// Re-create the bucket before taking the lock so readers are not blocked
	// on network I/O. newNatsKV returns a nil handle on error.
	kv, err := newNatsKV(s.nc, s.ctx, s.cfg.KeyValueConfig)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.kv = kv
	s.err = err
	return err
}
// Close closes the underlying NATS connection while holding the read lock.
// It always returns nil.
func (s *Storage) Close() error {
	s.mu.RLock()
	defer s.mu.RUnlock()

	s.nc.Close()
	return nil
}
// Conn returns the underlying NATS connection and the current JetStream KV
// handle. The handle is nil if no KV handle is available.
func (s *Storage) Conn() (*nats.Conn, jetstream.KeyValue) {
	s.mu.RLock()
	nc, kv := s.nc, s.kv
	s.mu.RUnlock()
	return nc, kv
}
// Keys returns every key currently stored in the bucket, or (nil, nil) when
// the bucket is empty.
//
// Expiry is not checked here. Entries that have expired but have not yet been
// removed by Get are still listed.
func (s *Storage) Keys() ([]string, error) {
	// Snapshot both fields under the lock: reading s.err after RUnlock would
	// race with the connection handlers.
	s.mu.RLock()
	kv, initErr := s.kv, s.err
	s.mu.RUnlock()
	if kv == nil {
		return nil, fmt.Errorf("kv not initialized: %v", initErr)
	}

	lister, err := kv.ListKeys(s.ctx)
	if err != nil {
		return nil, fmt.Errorf("keys: %w", err)
	}
	// Stop releases the lister; its error is irrelevant once the keys have
	// been collected.
	defer func() { _ = lister.Stop() }()

	// keys stays nil when nothing is received, which preserves the
	// (nil, nil) result for an empty bucket.
	var keys []string
	for key := range lister.Keys() {
		keys = append(keys, key)
	}
	return keys, nil
}