Implement certmagic.Storage on cluster

This commit is contained in:
Ingo Oppermann
2023-06-23 21:00:45 +02:00
parent fc49c97a9f
commit f37896a1e3
7 changed files with 730 additions and 9 deletions

View File

@@ -473,7 +473,7 @@ func (a *api) start() error {
}) })
} }
cluster, err := cluster.New(cluster.ClusterConfig{ cluster, err := cluster.New(cluster.Config{
ID: cfg.ID, ID: cfg.ID,
Name: cfg.Name, Name: cfg.Name,
Path: filepath.Join(cfg.DB.Dir, "cluster"), Path: filepath.Join(cfg.DB.Dir, "cluster"),

View File

@@ -647,7 +647,7 @@ func (a *api) Lock(c echo.Context) error {
a.logger.Debug().WithField("name", r.Name).Log("Acquire lock") a.logger.Debug().WithField("name", r.Name).Log("Acquire lock")
err := a.cluster.CreateLock(origin, r.Name, r.ValidUntil) _, err := a.cluster.CreateLock(origin, r.Name, r.ValidUntil)
if err != nil { if err != nil {
a.logger.Debug().WithError(err).WithField("name", r.Name).Log("Unable to acquire lock") a.logger.Debug().WithError(err).WithField("name", r.Name).Log("Unable to acquire lock")
return Err(http.StatusInternalServerError, "", "unable to acquire lock: %s", err.Error()) return Err(http.StatusInternalServerError, "", "unable to acquire lock: %s", err.Error())

View File

@@ -67,13 +67,14 @@ type Cluster interface {
SetPolicies(origin, name string, policies []iamaccess.Policy) error SetPolicies(origin, name string, policies []iamaccess.Policy) error
RemoveIdentity(origin string, name string) error RemoveIdentity(origin string, name string) error
CreateLock(origin string, name string, validUntil time.Time) error CreateLock(origin string, name string, validUntil time.Time) (*Lock, error)
DeleteLock(origin string, name string) error DeleteLock(origin string, name string) error
ListLocks() map[string]time.Time ListLocks() map[string]time.Time
SetKV(origin, key, value string) error SetKV(origin, key, value string) error
UnsetKV(origin, key string) error UnsetKV(origin, key string) error
GetKV(key string) (string, time.Time, error) GetKV(key string) (string, time.Time, error)
ListKV(prefix string) map[string]store.Value
ProxyReader() proxy.ProxyReader ProxyReader() proxy.ProxyReader
} }
@@ -83,7 +84,7 @@ type Peer struct {
Address string Address string
} }
type ClusterConfig struct { type Config struct {
ID string // ID of the node ID string // ID of the node
Name string // Name of the node Name string // Name of the node
Path string // Path where to store all cluster data Path string // Path where to store all cluster data
@@ -152,7 +153,7 @@ type cluster struct {
var ErrDegraded = errors.New("cluster is currently degraded") var ErrDegraded = errors.New("cluster is currently degraded")
func New(config ClusterConfig) (Cluster, error) { func New(config Config) (Cluster, error) {
c := &cluster{ c := &cluster{
id: config.ID, id: config.ID,
name: config.Name, name: config.Name,
@@ -1071,13 +1072,22 @@ func (c *cluster) RemoveIdentity(origin string, name string) error {
return c.applyCommand(cmd) return c.applyCommand(cmd)
} }
func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) error { func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) (*Lock, error) {
if ok, _ := c.IsDegraded(); ok { if ok, _ := c.IsDegraded(); ok {
return ErrDegraded return nil, ErrDegraded
} }
if !c.IsRaftLeader() { if !c.IsRaftLeader() {
return c.forwarder.CreateLock(origin, name, validUntil) err := c.forwarder.CreateLock(origin, name, validUntil)
if err != nil {
return nil, err
}
l := &Lock{
ValidUntil: validUntil,
}
return l, nil
} }
cmd := &store.Command{ cmd := &store.Command{
@@ -1088,7 +1098,16 @@ func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) e
}, },
} }
return c.applyCommand(cmd) err := c.applyCommand(cmd)
if err != nil {
return nil, err
}
l := &Lock{
ValidUntil: validUntil,
}
return l, nil
} }
func (c *cluster) DeleteLock(origin string, name string) error { func (c *cluster) DeleteLock(origin string, name string) error {
@@ -1162,6 +1181,12 @@ func (c *cluster) GetKV(key string) (string, time.Time, error) {
return value.Value, value.UpdatedAt, nil return value.Value, value.UpdatedAt, nil
} }
func (c *cluster) ListKV(prefix string) map[string]store.Value {
storeValues := c.store.ListKVS(prefix)
return storeValues
}
func (c *cluster) applyCommand(cmd *store.Command) error { func (c *cluster) applyCommand(cmd *store.Command) error {
b, err := json.Marshal(cmd) b, err := json.Marshal(cmd)
if err != nil { if err != nil {

221
cluster/kvs.go Normal file
View File

@@ -0,0 +1,221 @@
package cluster
import (
"context"
"fmt"
"io/fs"
"strings"
"sync"
"time"
"github.com/datarhei/core/v16/cluster/store"
)
type KVS interface {
CreateLock(name string, validUntil time.Time) (*Lock, error)
DeleteLock(name string) error
ListLocks() map[string]time.Time
SetKV(key, value string) error
UnsetKV(key string) error
GetKV(key string) (string, time.Time, error)
ListKV(prefix string) map[string]store.Value
}
type Lock struct {
ValidUntil time.Time
ctx context.Context
cancel context.CancelFunc
lock sync.Mutex
}
func (l *Lock) Expired() <-chan struct{} {
l.lock.Lock()
defer l.lock.Unlock()
if l.ctx == nil {
l.ctx, l.cancel = context.WithDeadline(context.Background(), l.ValidUntil)
go func(l *Lock) {
<-l.ctx.Done()
l.lock.Lock()
defer l.lock.Unlock()
if l.cancel != nil {
l.cancel()
}
}(l)
}
return l.ctx.Done()
}
func (l *Lock) Unlock() {
l.lock.Lock()
defer l.lock.Unlock()
if l.cancel != nil {
l.ValidUntil = time.Now()
l.cancel()
}
}
type clusterKVS struct {
cluster Cluster
}
func NewClusterKVS(cluster Cluster) (KVS, error) {
s := &clusterKVS{
cluster: cluster,
}
return s, nil
}
func (s *clusterKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) {
return s.cluster.CreateLock("", name, validUntil)
}
func (s *clusterKVS) DeleteLock(name string) error {
return s.cluster.DeleteLock("", name)
}
func (s *clusterKVS) ListLocks() map[string]time.Time {
return s.cluster.ListLocks()
}
func (s *clusterKVS) SetKV(key, value string) error {
return s.cluster.SetKV("", key, value)
}
func (s *clusterKVS) UnsetKV(key string) error {
return s.cluster.UnsetKV("", key)
}
func (s *clusterKVS) GetKV(key string) (string, time.Time, error) {
return s.cluster.GetKV(key)
}
func (s *clusterKVS) ListKV(prefix string) map[string]store.Value {
return s.cluster.ListKV(prefix)
}
type memKVS struct {
lock sync.Mutex
locks map[string]*Lock
values map[string]store.Value
}
func NewMemoryKVS() (KVS, error) {
return &memKVS{
locks: map[string]*Lock{},
values: map[string]store.Value{},
}, nil
}
func (s *memKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) {
s.lock.Lock()
defer s.lock.Unlock()
l, ok := s.locks[name]
if ok {
if time.Now().Before(l.ValidUntil) {
return nil, fmt.Errorf("the lock with the name '%s' already exists", name)
}
}
l = &Lock{
ValidUntil: validUntil,
}
s.locks[name] = l
return l, nil
}
func (s *memKVS) DeleteLock(name string) error {
s.lock.Lock()
defer s.lock.Unlock()
lock, ok := s.locks[name]
if !ok {
return fmt.Errorf("the lock with the name '%s' doesn't exist", name)
}
lock.Unlock()
delete(s.locks, name)
return nil
}
func (s *memKVS) ListLocks() map[string]time.Time {
s.lock.Lock()
defer s.lock.Unlock()
m := map[string]time.Time{}
for key, lock := range s.locks {
m[key] = lock.ValidUntil
}
return m
}
func (s *memKVS) SetKV(key, value string) error {
s.lock.Lock()
defer s.lock.Unlock()
v := s.values[key]
v.Value = value
v.UpdatedAt = time.Now()
s.values[key] = v
return nil
}
func (s *memKVS) UnsetKV(key string) error {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.values[key]; !ok {
return fs.ErrNotExist
}
delete(s.values, key)
return nil
}
func (s *memKVS) GetKV(key string) (string, time.Time, error) {
s.lock.Lock()
defer s.lock.Unlock()
v, ok := s.values[key]
if !ok {
return "", time.Time{}, fs.ErrNotExist
}
return v.Value, v.UpdatedAt, nil
}
func (s *memKVS) ListKV(prefix string) map[string]store.Value {
s.lock.Lock()
defer s.lock.Unlock()
m := map[string]store.Value{}
for key, value := range s.values {
if !strings.HasPrefix(key, prefix) {
continue
}
m[key] = value
}
return m
}

143
cluster/kvs_test.go Normal file
View File

@@ -0,0 +1,143 @@
package cluster
import (
"io/fs"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSetGetUnsetValue(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
_, _, err = kvs.GetKV("foo")
require.Error(t, err)
err = kvs.SetKV("foo", "bar")
require.NoError(t, err)
value, _, err := kvs.GetKV("foo")
require.NoError(t, err)
require.Equal(t, "bar", value)
err = kvs.UnsetKV("foo")
require.NoError(t, err)
_, _, err = kvs.GetKV("foo")
require.Error(t, err)
}
func TestKeyNotFound(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
_, _, err = kvs.GetKV("foo")
require.ErrorIs(t, err, fs.ErrNotExist)
err = kvs.UnsetKV("foo")
require.ErrorIs(t, err, fs.ErrNotExist)
}
func TestListKV(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
err = kvs.SetKV("foo", "bar")
require.NoError(t, err)
err = kvs.SetKV("foz", "baz")
require.NoError(t, err)
err = kvs.SetKV("bar", "foo")
require.NoError(t, err)
list := kvs.ListKV("")
require.Equal(t, 3, len(list))
list = kvs.ListKV("f")
require.Equal(t, 2, len(list))
list = kvs.ListKV("b")
require.Equal(t, 1, len(list))
list = kvs.ListKV("fo")
require.Equal(t, 2, len(list))
list = kvs.ListKV("foo")
require.Equal(t, 1, len(list))
}
func TestLock(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
until := time.Now().Add(5 * time.Second)
lock, err := kvs.CreateLock("foobar", until)
require.NoError(t, err)
require.Equal(t, until, lock.ValidUntil)
require.Eventually(t, func() bool {
select {
case <-lock.Expired():
return true
case <-time.After(10 * time.Millisecond):
return false
}
}, 10*time.Second, time.Second)
}
func TestLockCreate(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
until := time.Now().Add(5 * time.Second)
lock, err := kvs.CreateLock("foobar", until)
require.NoError(t, err)
require.Equal(t, until, lock.ValidUntil)
_, err = kvs.CreateLock("foobar", until)
require.Error(t, err)
err = kvs.DeleteLock("foobar")
require.NoError(t, err)
}
func TestLockDelete(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
err = kvs.DeleteLock("foobar")
require.Error(t, err)
until := time.Now().Add(5 * time.Second)
_, err = kvs.CreateLock("foobar", until)
require.NoError(t, err)
err = kvs.DeleteLock("foobar")
require.NoError(t, err)
}
func TestLocksList(t *testing.T) {
kvs, err := NewMemoryKVS()
require.NoError(t, err)
list := kvs.ListLocks()
require.Empty(t, list)
until := time.Now().Add(5 * time.Second)
_, err = kvs.CreateLock("foobar", until)
require.NoError(t, err)
list = kvs.ListLocks()
require.NotEmpty(t, list)
require.Equal(t, list["foobar"], until)
err = kvs.DeleteLock("foobar")
require.NoError(t, err)
list = kvs.ListLocks()
require.Empty(t, list)
}

168
cluster/tls.go Normal file
View File

@@ -0,0 +1,168 @@
package cluster
import (
"context"
"encoding/base64"
"fmt"
"io/fs"
"path"
"strings"
"sync"
"time"
"github.com/caddyserver/certmagic"
)
type clusterStorage struct {
kvs KVS
prefix string
locks map[string]*Lock
muLocks sync.Mutex
}
func NewClusterStorage(kvs KVS, prefix string) (certmagic.Storage, error) {
s := &clusterStorage{
kvs: kvs,
prefix: prefix,
locks: map[string]*Lock{},
}
return s, nil
}
func (s *clusterStorage) prefixKey(key string) string {
return path.Join(s.prefix, key)
}
func (s *clusterStorage) unprefixKey(key string) string {
return strings.TrimPrefix(key, s.prefix+"/")
}
func (s *clusterStorage) Lock(ctx context.Context, name string) error {
for {
lock, err := s.kvs.CreateLock(s.prefixKey(name), time.Now().Add(time.Minute))
if err == nil {
go func() {
<-lock.Expired()
s.Unlock(context.Background(), name)
}()
s.muLocks.Lock()
s.locks[name] = lock
s.muLocks.Unlock()
return nil
}
select {
case <-time.After(time.Second):
continue
case <-ctx.Done():
return ctx.Err()
case <-time.After(5 * time.Minute):
return fmt.Errorf("wasn't able to acquire lock")
}
}
}
func (s *clusterStorage) Unlock(ctx context.Context, name string) error {
err := s.kvs.DeleteLock(s.prefixKey(name))
if err != nil {
return err
}
s.muLocks.Lock()
delete(s.locks, name)
s.muLocks.Unlock()
return nil
}
// Store puts value at key.
func (s *clusterStorage) Store(ctx context.Context, key string, value []byte) error {
encodedValue := base64.StdEncoding.EncodeToString(value)
return s.kvs.SetKV(s.prefixKey(key), encodedValue)
}
// Load retrieves the value at key.
func (s *clusterStorage) Load(ctx context.Context, key string) ([]byte, error) {
encodedValue, _, err := s.kvs.GetKV(s.prefixKey(key))
if err != nil {
return nil, err
}
return base64.StdEncoding.DecodeString(encodedValue)
}
// Delete deletes key. An error should be
// returned only if the key still exists
// when the method returns.
func (s *clusterStorage) Delete(ctx context.Context, key string) error {
return s.kvs.UnsetKV(s.prefixKey(key))
}
// Exists returns true if the key exists
// and there was no error checking.
func (s *clusterStorage) Exists(ctx context.Context, key string) bool {
_, _, err := s.kvs.GetKV(s.prefixKey(key))
return err == nil
}
// List returns all keys that match prefix.
// If recursive is true, non-terminal keys
// will be enumerated (i.e. "directories"
// should be walked); otherwise, only keys
// prefixed exactly by prefix will be listed.
func (s *clusterStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) {
values := s.kvs.ListKV(s.prefixKey(prefix))
keys := []string{}
for key := range values {
keys = append(keys, s.unprefixKey(key))
}
if len(keys) == 0 {
return nil, fs.ErrNotExist
}
if recursive {
return keys, nil
}
prefix = strings.TrimSuffix(prefix, "/")
keyMap := map[string]struct{}{}
for _, key := range keys {
elms := strings.Split(strings.TrimPrefix(key, prefix+"/"), "/")
keyMap[elms[0]] = struct{}{}
}
keys = []string{}
for key := range keyMap {
keys = append(keys, path.Join(prefix, key))
}
return keys, nil
}
// Stat returns information about key.
func (s *clusterStorage) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) {
encodedValue, lastModified, err := s.kvs.GetKV(s.prefixKey(key))
if err != nil {
return certmagic.KeyInfo{}, err
}
value, err := base64.StdEncoding.DecodeString(encodedValue)
if err != nil {
return certmagic.KeyInfo{}, err
}
info := certmagic.KeyInfo{
Key: key,
Modified: lastModified,
Size: int64(len(value)),
IsTerminal: false,
}
return info, nil
}

164
cluster/tls_test.go Normal file
View File

@@ -0,0 +1,164 @@
package cluster
import (
"context"
"io/fs"
"path"
"testing"
"time"
"github.com/caddyserver/certmagic"
"github.com/stretchr/testify/require"
)
func setupStorage() (certmagic.Storage, error) {
kvs, err := NewMemoryKVS()
if err != nil {
return nil, err
}
return NewClusterStorage(kvs, "some_prefix")
}
func TestStorageStore(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"), []byte("crt data"))
require.NoError(t, err)
}
func TestStorageExists(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt")
err = cs.Store(context.Background(), key, []byte("crt data"))
require.NoError(t, err)
exists := cs.Exists(context.Background(), key)
require.True(t, exists)
}
func TestStorageLoad(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt")
content := []byte("crt data")
err = cs.Store(context.Background(), key, content)
require.NoError(t, err)
contentLoded, err := cs.Load(context.Background(), key)
require.NoError(t, err)
require.Equal(t, content, contentLoded)
}
func TestStorageDelete(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt")
content := []byte("crt data")
err = cs.Store(context.Background(), key, content)
require.NoError(t, err)
err = cs.Delete(context.Background(), key)
require.NoError(t, err)
exists := cs.Exists(context.Background(), key)
require.False(t, exists)
contentLoaded, err := cs.Load(context.Background(), key)
require.Nil(t, contentLoaded)
require.ErrorIs(t, err, fs.ErrNotExist)
}
func TestStorageStat(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
key := path.Join("acme", "example.com", "sites", "example.com", "example.com.crt")
content := []byte("crt data")
err = cs.Store(context.Background(), key, content)
require.NoError(t, err)
info, err := cs.Stat(context.Background(), key)
require.NoError(t, err)
require.Equal(t, key, info.Key)
}
func TestStorageList(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"), []byte("crt"))
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.key"), []byte("key"))
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.json"), []byte("meta"))
require.NoError(t, err)
keys, err := cs.List(context.Background(), path.Join("acme", "example.com", "sites", "example.com"), true)
require.NoError(t, err)
require.Len(t, keys, 3)
require.Contains(t, keys, path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"))
}
func TestStorageListNonRecursive(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.crt"), []byte("crt"))
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.key"), []byte("key"))
require.NoError(t, err)
err = cs.Store(context.Background(), path.Join("acme", "example.com", "sites", "example.com", "example.com.json"), []byte("meta"))
require.NoError(t, err)
keys, err := cs.List(context.Background(), path.Join("acme", "example.com", "sites"), false)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Contains(t, keys, path.Join("acme", "example.com", "sites", "example.com"))
}
func TestStorageLockUnlock(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
lockKey := path.Join("acme", "example.com", "sites", "example.com", "lock")
err = cs.Lock(context.Background(), lockKey)
require.NoError(t, err)
err = cs.Unlock(context.Background(), lockKey)
require.NoError(t, err)
}
func TestStorageTwoLocks(t *testing.T) {
cs, err := setupStorage()
require.NoError(t, err)
lockKey := path.Join("acme", "example.com", "sites", "example.com", "lock")
err = cs.Lock(context.Background(), lockKey)
require.NoError(t, err)
go time.AfterFunc(5*time.Second, func() {
err := cs.Unlock(context.Background(), lockKey)
require.NoError(t, err)
})
err = cs.Lock(context.Background(), lockKey)
require.NoError(t, err)
err = cs.Unlock(context.Background(), lockKey)
require.NoError(t, err)
}