diff --git a/app/api/api.go b/app/api/api.go index cf7e7169..466ca4ec 100644 --- a/app/api/api.go +++ b/app/api/api.go @@ -473,7 +473,7 @@ func (a *api) start() error { }) } - cluster, err := cluster.New(cluster.ClusterConfig{ + cluster, err := cluster.New(cluster.Config{ ID: cfg.ID, Name: cfg.Name, Path: filepath.Join(cfg.DB.Dir, "cluster"), diff --git a/cluster/api.go b/cluster/api.go index 14c671c1..54a79ca1 100644 --- a/cluster/api.go +++ b/cluster/api.go @@ -647,7 +647,7 @@ func (a *api) Lock(c echo.Context) error { a.logger.Debug().WithField("name", r.Name).Log("Acquire lock") - err := a.cluster.CreateLock(origin, r.Name, r.ValidUntil) + _, err := a.cluster.CreateLock(origin, r.Name, r.ValidUntil) if err != nil { a.logger.Debug().WithError(err).WithField("name", r.Name).Log("Unable to acquire lock") return Err(http.StatusInternalServerError, "", "unable to acquire lock: %s", err.Error()) diff --git a/cluster/cluster.go b/cluster/cluster.go index 7f67ccc0..f2a31eb9 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -67,13 +67,14 @@ type Cluster interface { SetPolicies(origin, name string, policies []iamaccess.Policy) error RemoveIdentity(origin string, name string) error - CreateLock(origin string, name string, validUntil time.Time) error + CreateLock(origin string, name string, validUntil time.Time) (*Lock, error) DeleteLock(origin string, name string) error ListLocks() map[string]time.Time SetKV(origin, key, value string) error UnsetKV(origin, key string) error GetKV(key string) (string, time.Time, error) + ListKV(prefix string) map[string]store.Value ProxyReader() proxy.ProxyReader } @@ -83,7 +84,7 @@ type Peer struct { Address string } -type ClusterConfig struct { +type Config struct { ID string // ID of the node Name string // Name of the node Path string // Path where to store all cluster data @@ -152,7 +153,7 @@ type cluster struct { var ErrDegraded = errors.New("cluster is currently degraded") -func New(config ClusterConfig) (Cluster, error) { +func New(config Config) (Cluster, error) { c := &cluster{ id: config.ID, name: config.Name, @@ -1071,13 +1072,22 @@ func (c *cluster) RemoveIdentity(origin string, name string) error { return c.applyCommand(cmd) } -func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) error { +func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) (*Lock, error) { if ok, _ := c.IsDegraded(); ok { - return ErrDegraded + return nil, ErrDegraded } if !c.IsRaftLeader() { - return c.forwarder.CreateLock(origin, name, validUntil) + err := c.forwarder.CreateLock(origin, name, validUntil) + if err != nil { + return nil, err + } + + l := &Lock{ + ValidUntil: validUntil, + } + + return l, nil } cmd := &store.Command{ @@ -1088,7 +1098,16 @@ func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) e }, } - return c.applyCommand(cmd) + err := c.applyCommand(cmd) + if err != nil { + return nil, err + } + + l := &Lock{ + ValidUntil: validUntil, + } + + return l, nil } func (c *cluster) DeleteLock(origin string, name string) error { @@ -1162,6 +1181,12 @@ func (c *cluster) GetKV(key string) (string, time.Time, error) { return value.Value, value.UpdatedAt, nil } +func (c *cluster) ListKV(prefix string) map[string]store.Value { + storeValues := c.store.ListKVS(prefix) + + return storeValues +} + func (c *cluster) applyCommand(cmd *store.Command) error { b, err := json.Marshal(cmd) if err != nil { diff --git a/cluster/kvs.go b/cluster/kvs.go new file mode 100644 index 00000000..12ce8e8c --- /dev/null +++ b/cluster/kvs.go @@ -0,0 +1,221 @@ +package cluster + +import ( + "context" + "fmt" + "io/fs" + "strings" + "sync" + "time" + + "github.com/datarhei/core/v16/cluster/store" +) + +type KVS interface { + CreateLock(name string, validUntil time.Time) (*Lock, error) + DeleteLock(name string) error + ListLocks() map[string]time.Time + + SetKV(key, value string) error + UnsetKV(key string) error + GetKV(key string) (string, time.Time, error) + ListKV(prefix string) map[string]store.Value +} + +type Lock struct { + ValidUntil time.Time + ctx context.Context + cancel context.CancelFunc + lock sync.Mutex +} + +func (l *Lock) Expired() <-chan struct{} { + l.lock.Lock() + defer l.lock.Unlock() + + if l.ctx == nil { + l.ctx, l.cancel = context.WithDeadline(context.Background(), l.ValidUntil) + + go func(l *Lock) { + <-l.ctx.Done() + + l.lock.Lock() + defer l.lock.Unlock() + if l.cancel != nil { + l.cancel() + } + }(l) + } + + return l.ctx.Done() +} + +func (l *Lock) Unlock() { + l.lock.Lock() + defer l.lock.Unlock() + + if l.cancel != nil { + l.ValidUntil = time.Now() + l.cancel() + } +} + +type clusterKVS struct { + cluster Cluster +} + +func NewClusterKVS(cluster Cluster) (KVS, error) { + s := &clusterKVS{ + cluster: cluster, + } + + return s, nil +} + +func (s *clusterKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) { + return s.cluster.CreateLock("", name, validUntil) +} + +func (s *clusterKVS) DeleteLock(name string) error { + return s.cluster.DeleteLock("", name) +} + +func (s *clusterKVS) ListLocks() map[string]time.Time { + return s.cluster.ListLocks() +} + +func (s *clusterKVS) SetKV(key, value string) error { + return s.cluster.SetKV("", key, value) +} + +func (s *clusterKVS) UnsetKV(key string) error { + return s.cluster.UnsetKV("", key) +} + +func (s *clusterKVS) GetKV(key string) (string, time.Time, error) { + return s.cluster.GetKV(key) +} + +func (s *clusterKVS) ListKV(prefix string) map[string]store.Value { + return s.cluster.ListKV(prefix) +} + +type memKVS struct { + lock sync.Mutex + + locks map[string]*Lock + values map[string]store.Value +} + +func NewMemoryKVS() (KVS, error) { + return &memKVS{ + locks: map[string]*Lock{}, + values: map[string]store.Value{}, + }, nil +} + +func (s *memKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) { + s.lock.Lock() + defer s.lock.Unlock() + + l, ok := s.locks[name] + + if ok { + if time.Now().Before(l.ValidUntil) { + return nil, fmt.Errorf("the lock with the name '%s' already exists", name) + } + } + + l = &Lock{ + ValidUntil: validUntil, + } + + s.locks[name] = l + + return l, nil +} + +func (s *memKVS) DeleteLock(name string) error { + s.lock.Lock() + defer s.lock.Unlock() + + lock, ok := s.locks[name] + if !ok { + return fmt.Errorf("the lock with the name '%s' doesn't exist", name) + } + + lock.Unlock() + + delete(s.locks, name) + + return nil +} + +func (s *memKVS) ListLocks() map[string]time.Time { + s.lock.Lock() + defer s.lock.Unlock() + + m := map[string]time.Time{} + + for key, lock := range s.locks { + m[key] = lock.ValidUntil + } + + return m +} + +func (s *memKVS) SetKV(key, value string) error { + s.lock.Lock() + defer s.lock.Unlock() + + v := s.values[key] + + v.Value = value + v.UpdatedAt = time.Now() + + s.values[key] = v + + return nil +} + +func (s *memKVS) UnsetKV(key string) error { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.values[key]; !ok { + return fs.ErrNotExist + } + + delete(s.values, key) + + return nil +} + +func (s *memKVS) GetKV(key string) (string, time.Time, error) { + s.lock.Lock() + defer s.lock.Unlock() + + v, ok := s.values[key] + if !ok { + return "", time.Time{}, fs.ErrNotExist + } + + return v.Value, v.UpdatedAt, nil +} + +func (s *memKVS) ListKV(prefix string) map[string]store.Value { + s.lock.Lock() + defer s.lock.Unlock() + + m := map[string]store.Value{} + + for key, value := range s.values { + if !strings.HasPrefix(key, prefix) { + continue + } + + m[key] = value + } + + return m +} diff --git a/cluster/kvs_test.go b/cluster/kvs_test.go new file mode 100644 index 00000000..900cb38f --- /dev/null +++ b/cluster/kvs_test.go @@ -0,0 +1,143 @@ +package cluster + +import ( + "io/fs" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSetGetUnsetValue(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + _, _, err = kvs.GetKV("foo") + require.Error(t, err) + + err = kvs.SetKV("foo", "bar") + require.NoError(t, err) + + value, _, err := kvs.GetKV("foo") + require.NoError(t, err) + require.Equal(t, "bar", value) + + err = kvs.UnsetKV("foo") + require.NoError(t, err) + + _, _, err = kvs.GetKV("foo") + require.Error(t, err) +} + +func TestKeyNotFound(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + _, _, err = kvs.GetKV("foo") + require.ErrorIs(t, err, fs.ErrNotExist) + + err = kvs.UnsetKV("foo") + require.ErrorIs(t, err, fs.ErrNotExist) +} + +func TestListKV(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + err = kvs.SetKV("foo", "bar") + require.NoError(t, err) + + err = kvs.SetKV("foz", "baz") + require.NoError(t, err) + + err = kvs.SetKV("bar", "foo") + require.NoError(t, err) + + list := kvs.ListKV("") + require.Equal(t, 3, len(list)) + + list = kvs.ListKV("f") + require.Equal(t, 2, len(list)) + + list = kvs.ListKV("b") + require.Equal(t, 1, len(list)) + + list = kvs.ListKV("fo") + require.Equal(t, 2, len(list)) + + list = kvs.ListKV("foo") + require.Equal(t, 1, len(list)) +} + +func TestLock(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + until := time.Now().Add(5 * time.Second) + + lock, err := kvs.CreateLock("foobar", until) + require.NoError(t, err) + require.Equal(t, until, lock.ValidUntil) + + require.Eventually(t, func() bool { + select { + case <-lock.Expired(): + return true + case <-time.After(10 * time.Millisecond): + return false + } + }, 10*time.Second, time.Second) +} + +func TestLockCreate(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + until := time.Now().Add(5 * time.Second) + lock, err := kvs.CreateLock("foobar", until) + require.NoError(t, err) + require.Equal(t, until, lock.ValidUntil) + + _, err = kvs.CreateLock("foobar", until) + require.Error(t, err) + + err = kvs.DeleteLock("foobar") + require.NoError(t, err) +} + +func TestLockDelete(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + err = kvs.DeleteLock("foobar") + require.Error(t, err) + + until := time.Now().Add(5 * time.Second) + _, err = kvs.CreateLock("foobar", until) + require.NoError(t, err) + + err = kvs.DeleteLock("foobar") + require.NoError(t, err) +} + +func TestLocksList(t *testing.T) { + kvs, err := NewMemoryKVS() + require.NoError(t, err) + + list := kvs.ListLocks() + require.Empty(t, list) + + until := time.Now().Add(5 * time.Second) + _, err = kvs.CreateLock("foobar", until) + require.NoError(t, err) + + list = kvs.ListLocks() + require.NotEmpty(t, list) + require.Equal(t, list["foobar"], until) + + err = kvs.DeleteLock("foobar") + require.NoError(t, err) + + list = kvs.ListLocks() + require.Empty(t, list) +} diff --git a/cluster/tls.go b/cluster/tls.go new file mode 100644 index 00000000..d1f42143 --- /dev/null +++ b/cluster/tls.go @@ -0,0 +1,168 @@ +package cluster + +import ( + "context" + "encoding/base64" + "fmt" + "io/fs" + "path" + "strings" + "sync" + "time" + + "github.com/caddyserver/certmagic" +) + +type clusterStorage struct { + kvs KVS + prefix string + locks map[string]*Lock + muLocks sync.Mutex +} + +func NewClusterStorage(kvs KVS, prefix string) (certmagic.Storage, error) { + s := &clusterStorage{ + kvs: kvs, + prefix: prefix, + locks: map[string]*Lock{}, + } + + return s, nil +} + +func (s *clusterStorage) prefixKey(key string) string { + return path.Join(s.prefix, key) +} + +func (s *clusterStorage) unprefixKey(key string) string { + return strings.TrimPrefix(key, s.prefix+"/") +} + +func (s *clusterStorage) Lock(ctx context.Context, name string) error { + for { + lock, err := s.kvs.CreateLock(s.prefixKey(name), time.Now().Add(time.Minute)) + if err == nil { + go func() { + <-lock.Expired() + s.Unlock(context.Background(), name) + }() + s.muLocks.Lock() + s.locks[name] = lock + s.muLocks.Unlock() + + return nil + } + + select { + case <-time.After(time.Second): + continue + case <-ctx.Done(): + return ctx.Err() + case <-time.After(5 * time.Minute): + return fmt.Errorf("wasn't able to acquire lock") + } + } +} + +func (s *clusterStorage) Unlock(ctx context.Context, name string) error { + err := s.kvs.DeleteLock(s.prefixKey(name)) + if err != nil { + return err + } + + s.muLocks.Lock() + delete(s.locks, name) + s.muLocks.Unlock() + + return nil +} + +// Store puts value at key. +func (s *clusterStorage) Store(ctx context.Context, key string, value []byte) error { + encodedValue := base64.StdEncoding.EncodeToString(value) + return s.kvs.SetKV(s.prefixKey(key), encodedValue) +} + +// Load retrieves the value at key. +func (s *clusterStorage) Load(ctx context.Context, key string) ([]byte, error) { + encodedValue, _, err := s.kvs.GetKV(s.prefixKey(key)) + if err != nil { + return nil, err + } + + return base64.StdEncoding.DecodeString(encodedValue) +} + +// Delete deletes key. An error should be +// returned only if the key still exists +// when the method returns. +func (s *clusterStorage) Delete(ctx context.Context, key string) error { + return s.kvs.UnsetKV(s.prefixKey(key)) +} + +// Exists returns true if the key exists +// and there was no error checking. +func (s *clusterStorage) Exists(ctx context.Context, key string) bool { + _, _, err := s.kvs.GetKV(s.prefixKey(key)) + return err == nil +} + +// List returns all keys that match prefix. +// If recursive is true, non-terminal keys +// will be enumerated (i.e. "directories" +// should be walked); otherwise, only keys +// prefixed exactly by prefix will be listed. +func (s *clusterStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { + values := s.kvs.ListKV(s.prefixKey(prefix)) + + keys := []string{} + + for key := range values { + keys = append(keys, s.unprefixKey(key)) + } + + if len(keys) == 0 { + return nil, fs.ErrNotExist + } + + if recursive { + return keys, nil + } + + prefix = strings.TrimSuffix(prefix, "/") + + keyMap := map[string]struct{}{} + for _, key := range keys { + elms := strings.Split(strings.TrimPrefix(key, prefix+"/"), "/") + keyMap[elms[0]] = struct{}{} + } + + keys = []string{} + for key := range keyMap { + keys = append(keys, path.Join(prefix, key)) + } + + return keys, nil +} + +// Stat returns information about key. +func (s *clusterStorage) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) { + encodedValue, lastModified, err := s.kvs.GetKV(s.prefixKey(key)) + if err != nil { + return certmagic.KeyInfo{}, err + } + + value, err := base64.StdEncoding.DecodeString(encodedValue) + if err != nil { + return certmagic.KeyInfo{}, err + } + + info := certmagic.KeyInfo{ + Key: key, + Modified: lastModified, + Size: int64(len(value)), + IsTerminal: false, + } + + return info, nil +} diff --git a/cluster/tls_test.go b/cluster/tls_test.go new file mode 100644 index 00000000..78837946 --- /dev/null +++ b/cluster/tls_test.go @@ -0,0 +1,164 @@ +package cluster + +import ( + "context" + "io/fs" + "path" + "testing" + "time" + + "github.com/caddyserver/certmagic" + "github.com/stretchr/testify/require" +) + +func setupStorage() (certmagic.Storage, error) { + kvs, err := NewMemoryKVS() + if err != nil { + return nil, err + } + + return NewClusterStorage(kvs, "some_prefix") +} + +func TestStorageStore(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"), []byte("crt data")) + require.NoError(t, err) +} + +func TestStorageExists(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt") + + err = cs.Store(context.Background(), key, []byte("crt data")) + require.NoError(t, err) + + exists := cs.Exists(context.Background(), key) + require.True(t, exists) +} + +func TestStorageLoad(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt") + content := []byte("crt data") + + err = cs.Store(context.Background(), key, content) + require.NoError(t, err) + + contentLoded, err := cs.Load(context.Background(), key) + require.NoError(t, err) + + require.Equal(t, content, contentLoded) +} + +func TestStorageDelete(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt") + content := []byte("crt data") + + err = cs.Store(context.Background(), key, content) + require.NoError(t, err) + + err = cs.Delete(context.Background(), key) + require.NoError(t, err) + + exists := cs.Exists(context.Background(), key) + require.False(t, exists) + + contentLoaded, err := cs.Load(context.Background(), key) + require.Nil(t, contentLoaded) + require.ErrorIs(t, err, fs.ErrNotExist) +} + +func TestStorageStat(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt") + content := []byte("crt data") + + err = cs.Store(context.Background(), key, content) + require.NoError(t, err) + + info, err := cs.Stat(context.Background(), key) + require.NoError(t, err) + + require.Equal(t, key, info.Key) +} + +func TestStorageList(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"), []byte("crt")) + require.NoError(t, err) + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.key"), []byte("key")) + require.NoError(t, err) + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.json"), []byte("meta")) + require.NoError(t, err) + + keys, err := cs.List(context.Background(), path.Join("acme", "example.com", "sites", "example.com"), true) + require.NoError(t, err) + require.Len(t, keys, 3) + require.Contains(t, keys, path.Join("acme", "example.com", "sites", "example.com", "example.com.crt")) +} + +func TestStorageListNonRecursive(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"), []byte("crt")) + require.NoError(t, err) + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.key"), []byte("key")) + require.NoError(t, err) + err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.json"), []byte("meta")) + require.NoError(t, err) + + keys, err := cs.List(context.Background(), path.Join("acme", "example.com", "sites"), false) + require.NoError(t, err) + + require.Len(t, keys, 1) + require.Contains(t, keys, path.Join("acme", "example.com", "sites", "example.com")) +} + +func TestStorageLockUnlock(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + lockKey := path.Join("acme", "example.com", "sites", "example.com", "lock") + + err = cs.Lock(context.Background(), lockKey) + require.NoError(t, err) + + err = cs.Unlock(context.Background(), lockKey) + require.NoError(t, err) +} + +func TestStorageTwoLocks(t *testing.T) { + cs, err := setupStorage() + require.NoError(t, err) + + lockKey := path.Join("acme", "example.com", "sites", "example.com", "lock") + + err = cs.Lock(context.Background(), lockKey) + require.NoError(t, err) + + go time.AfterFunc(5*time.Second, func() { + err := cs.Unlock(context.Background(), lockKey) + require.NoError(t, err) + }) + + err = cs.Lock(context.Background(), lockKey) + require.NoError(t, err) + + err = cs.Unlock(context.Background(), lockKey) + require.NoError(t, err) +}