diff --git a/app/api/api.go b/app/api/api.go index e5bfc30e..8d0ff70a 100644 --- a/app/api/api.go +++ b/app/api/api.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/caddyserver/certmagic" "github.com/datarhei/core/v16/app" "github.com/datarhei/core/v16/autocert" "github.com/datarhei/core/v16/cluster" @@ -50,7 +51,6 @@ import ( srturl "github.com/datarhei/core/v16/srt/url" "github.com/datarhei/core/v16/update" - "github.com/caddyserver/certmagic" "github.com/lestrrat-go/strftime" "go.uber.org/automaxprocs/maxprocs" ) @@ -478,10 +478,18 @@ func (a *api) start(ctx context.Context) error { } if a.cluster == nil { + var storage certmagic.Storage + storage = &certmagic.FileStorage{ + Path: filepath.Join(cfg.DB.Dir, "cert"), + } + + if len(cfg.TLS.Secret) != 0 { + crypto := autocert.NewCrypto(cfg.TLS.Secret) + storage = autocert.NewCryptoStorage(storage, crypto) + } + manager, err := autocert.New(autocert.Config{ - Storage: &certmagic.FileStorage{ - Path: filepath.Join(cfg.DB.Dir, "cert"), - }, + Storage: storage, DefaultHostname: cfg.Host.Name[0], EmailAddress: cfg.TLS.Email, IsProduction: !cfg.TLS.Staging, diff --git a/autocert/autocert.go b/autocert/autocert.go index 4285dde3..ce6fbfff 100644 --- a/autocert/autocert.go +++ b/autocert/autocert.go @@ -10,6 +10,7 @@ import ( "time" "github.com/datarhei/core/v16/log" + "github.com/datarhei/core/v16/slices" "github.com/caddyserver/certmagic" "github.com/klauspost/cpuid/v2" @@ -168,7 +169,7 @@ func (m *manager) HTTPChallengeResolver(ctx context.Context, listenAddress strin // AcquireCertificates tries to acquire the certificates for the given hostnames synchronously. func (m *manager) AcquireCertificates(ctx context.Context, hostnames []string) error { m.lock.Lock() - added, removed := diffStringSlice(hostnames, m.hostnames) + added, removed := slices.Diff(hostnames, m.hostnames) m.lock.Unlock() var err error @@ -201,7 +202,7 @@ func (m *manager) AcquireCertificates(ctx context.Context, hostnames []string) e // ManageCertificates is the same as AcquireCertificates but it does it in the background. func (m *manager) ManageCertificates(ctx context.Context, hostnames []string) error { m.lock.Lock() - added, removed := diffStringSlice(hostnames, m.hostnames) + added, removed := slices.Diff(hostnames, m.hostnames) m.hostnames = make([]string, len(hostnames)) copy(m.hostnames, hostnames) m.lock.Unlock() @@ -286,30 +287,3 @@ var ( tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, } ) - -// diffHostnames returns a list of newly added hostnames and a list of removed hostnames based -// the provided list and the list of currently managed hostnames. -func diffStringSlice(next, current []string) ([]string, []string) { - added, removed := []string{}, []string{} - - currentMap := map[string]struct{}{} - - for _, name := range current { - currentMap[name] = struct{}{} - } - - for _, name := range next { - if _, ok := currentMap[name]; ok { - delete(currentMap, name) - continue - } - - added = append(added, name) - } - - for name := range currentMap { - removed = append(removed, name) - } - - return added, removed -} diff --git a/autocert/crypto.go b/autocert/crypto.go new file mode 100644 index 00000000..4169dc34 --- /dev/null +++ b/autocert/crypto.go @@ -0,0 +1,106 @@ +package autocert + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "io" + + "golang.org/x/crypto/scrypt" +) + +type Crypto interface { + // Encrypt encrypts the given data or returns error if encrypting the data is not possible. + Encrypt(data []byte) ([]byte, error) + + // Decrypt decrypts the given data or returns error if decrypting the data is not possible. + Decrypt(data []byte) ([]byte, error) +} + +type crypto struct { + secret []byte +} + +// NewCrypto returns a new implementation of the the Crypto interface that encrypts/decrypts +// the given data with a key and salt derived from the provided secret. +// Based on https://itnext.io/encrypt-data-with-a-password-in-go-b5366384e291 +func NewCrypto(secret string) Crypto { + c := &crypto{ + secret: []byte(secret), + } + + return c +} + +func (c *crypto) Encrypt(data []byte) ([]byte, error) { + key, salt, err := c.deriveKey(nil) + if err != nil { + return nil, err + } + + blockCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + // The first gcm.NonceSize() are the nonce + ciphertext := gcm.Seal(nonce, nonce, data, nil) + // The last 32 bytes are the salt + ciphertext = append(ciphertext, salt...) + + return ciphertext, nil +} + +func (c *crypto) Decrypt(data []byte) ([]byte, error) { + // The last 32 bytes are the salt + salt, data := data[len(data)-32:], data[:len(data)-32] + key, _, err := c.deriveKey(salt) + if err != nil { + return nil, err + } + + blockCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, err + } + + // The first gcm.NonceSize() are the nonce + nonce, ciphertext := data[:gcm.NonceSize()], data[gcm.NonceSize():] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + + return plaintext, nil +} + +func (c *crypto) deriveKey(salt []byte) ([]byte, []byte, error) { + if salt == nil { + salt = make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return nil, nil, err + } + } + + key, err := scrypt.Key(c.secret, salt, 32768, 8, 1, 32) + if err != nil { + return nil, nil, err + } + + return key, salt, nil +} diff --git a/autocert/crypto_test.go b/autocert/crypto_test.go new file mode 100644 index 00000000..c260aa47 --- /dev/null +++ b/autocert/crypto_test.go @@ -0,0 +1,39 @@ +package autocert + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEncryptDecrypt(t *testing.T) { + c := NewCrypto("foobar") + + data := "top secret" + + encryptedData, err := c.Encrypt([]byte(data)) + require.NoError(t, err) + require.NotEqual(t, []byte(data), encryptedData) + + decryptedData, err := c.Decrypt(encryptedData) + require.NoError(t, err) + require.Equal(t, []byte(data), decryptedData) +} + +func TestEncryptDecryptWrongSecret(t *testing.T) { + c1 := NewCrypto("foobar") + c2 := NewCrypto("foobaz") + + data := "top secret" + + encryptedData, err := c1.Encrypt([]byte(data)) + require.NoError(t, err) + require.NotEqual(t, []byte(data), encryptedData) + + _, err = c2.Decrypt(encryptedData) + require.Error(t, err) + + decryptedData, err := c1.Decrypt(encryptedData) + require.NoError(t, err) + require.Equal(t, []byte(data), decryptedData) +} diff --git a/autocert/storage.go b/autocert/storage.go new file mode 100644 index 00000000..34b2c67a --- /dev/null +++ b/autocert/storage.go @@ -0,0 +1,81 @@ +package autocert + +import ( + "context" + + "github.com/caddyserver/certmagic" +) + +type cryptoStorage struct { + secret Crypto + + storage certmagic.Storage +} + +func NewCryptoStorage(storage certmagic.Storage, secret Crypto) certmagic.Storage { + s := &cryptoStorage{ + secret: secret, + storage: storage, + } + + return s +} + +func (s *cryptoStorage) Lock(ctx context.Context, name string) error { + return s.storage.Lock(ctx, name) +} + +func (s *cryptoStorage) Unlock(ctx context.Context, name string) error { + return s.storage.Unlock(ctx, name) +} + +func (s *cryptoStorage) Store(ctx context.Context, key string, value []byte) error { + encryptedValue, err := s.secret.Encrypt(value) + if err != nil { + return err + } + + return s.storage.Store(ctx, key, encryptedValue) +} + +func (s *cryptoStorage) Load(ctx context.Context, key string) ([]byte, error) { + encryptedValue, err := s.storage.Load(ctx, key) + if err != nil { + return nil, err + } + + value, err := s.secret.Decrypt(encryptedValue) + if err != nil { + return nil, err + } + + return value, nil +} + +func (s *cryptoStorage) Delete(ctx context.Context, key string) error { + return s.storage.Delete(ctx, key) +} + +func (s *cryptoStorage) Exists(ctx context.Context, key string) bool { + return s.storage.Exists(ctx, key) +} + +func (s *cryptoStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { + return s.storage.List(ctx, prefix, recursive) +} + +func (s *cryptoStorage) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) { + keyInfo, err := s.storage.Stat(ctx, key) + if err != nil { + return certmagic.KeyInfo{}, err + } + + value, err := s.Load(ctx, key) + if err != nil { + return certmagic.KeyInfo{}, err + } + + keyInfo.Size = int64(len(value)) + + return keyInfo, nil +} diff --git a/autocert/storage_test.go b/autocert/storage_test.go new file mode 100644 index 00000000..4da4187a --- /dev/null +++ b/autocert/storage_test.go @@ -0,0 +1,98 @@ +package autocert + +import ( + "context" + "io/fs" + "os" + "testing" + + "github.com/caddyserver/certmagic" + "github.com/stretchr/testify/require" +) + +func getCryptoStorage() certmagic.Storage { + s := &certmagic.FileStorage{ + Path: "./testing", + } + c := NewCrypto("secret") + + sc := NewCryptoStorage(s, c) + + return sc +} + +func TestFileStorageStoreLoad(t *testing.T) { + s := getCryptoStorage() + defer os.RemoveAll("./testing/") + + data := []byte("some data") + ctx := context.Background() + + err := s.Store(ctx, "foo", data) + require.NoError(t, err) + + loadedData, err := s.Load(ctx, "foo") + require.NoError(t, err) + require.Equal(t, data, loadedData) +} + +func TestFileStorageDelete(t *testing.T) { + s := getCryptoStorage() + defer os.RemoveAll("./testing/") + + data := []byte("some data") + ctx := context.Background() + + err := s.Delete(ctx, "foo") + require.ErrorIs(t, err, fs.ErrNotExist) + + err = s.Store(ctx, "foo", data) + require.NoError(t, err) + + err = s.Delete(ctx, "foo") + require.NoError(t, err) + + _, err = s.Load(ctx, "foo") + require.Error(t, err, fs.ErrNotExist) +} + +func TestFileStorageExists(t *testing.T) { + s := getCryptoStorage() + defer os.RemoveAll("./testing/") + + data := []byte("some data") + ctx := context.Background() + + b := s.Exists(ctx, "foo") + require.False(t, b) + + err := s.Store(ctx, "foo", data) + require.NoError(t, err) + + b = s.Exists(ctx, "foo") + require.True(t, b) + + err = s.Delete(ctx, "foo") + require.NoError(t, err) + + b = s.Exists(ctx, "foo") + require.False(t, b) +} + +func TestFileStorageStat(t *testing.T) { + s := getCryptoStorage() + defer os.RemoveAll("./testing/") + + data := []byte("some data") + ctx := context.Background() + + err := s.Store(ctx, "foo", data) + require.NoError(t, err) + + info, err := s.Stat(ctx, "foo") + require.NoError(t, err) + + require.Equal(t, "foo", info.Key) + require.Equal(t, int64(len(data)), info.Size) + require.Equal(t, true, info.IsTerminal) +} diff --git a/cluster/tls.go b/cluster/autocert/storage.go similarity index 77% rename from cluster/tls.go rename to cluster/autocert/storage.go index a8649c62..862c14ec 100644 --- a/cluster/tls.go +++ b/cluster/autocert/storage.go @@ -1,4 +1,4 @@ -package cluster +package autocert import ( "context" @@ -11,23 +11,24 @@ import ( "time" "github.com/caddyserver/certmagic" + "github.com/datarhei/core/v16/cluster/kvs" "github.com/datarhei/core/v16/log" ) -type clusterStorage struct { - kvs KVS +type storage struct { + kvs kvs.KVS prefix string - locks map[string]*Lock + locks map[string]*kvs.Lock muLocks sync.Mutex logger log.Logger } -func NewClusterStorage(kvs KVS, prefix string, logger log.Logger) (certmagic.Storage, error) { - s := &clusterStorage{ - kvs: kvs, +func NewStorage(kv kvs.KVS, prefix string, logger log.Logger) (certmagic.Storage, error) { + s := &storage{ + kvs: kv, prefix: prefix, - locks: map[string]*Lock{}, + locks: map[string]*kvs.Lock{}, logger: logger, } @@ -38,15 +39,15 @@ func NewClusterStorage(kvs KVS, prefix string, logger log.Logger) (certmagic.Sto return s, nil } -func (s *clusterStorage) prefixKey(key string) string { +func (s *storage) prefixKey(key string) string { return path.Join(s.prefix, key) } -func (s *clusterStorage) unprefixKey(key string) string { +func (s *storage) unprefixKey(key string) string { return strings.TrimPrefix(key, s.prefix+"/") } -func (s *clusterStorage) Lock(ctx context.Context, name string) error { +func (s *storage) Lock(ctx context.Context, name string) error { s.logger.Debug().WithField("name", name).Log("StorageLock") for { lock, err := s.kvs.CreateLock(s.prefixKey(name), time.Now().Add(time.Minute)) @@ -73,7 +74,7 @@ func (s *clusterStorage) Lock(ctx context.Context, name string) error { } } -func (s *clusterStorage) Unlock(ctx context.Context, name string) error { +func (s *storage) Unlock(ctx context.Context, name string) error { s.logger.Debug().WithField("name", name).Log("StorageUnlock") err := s.kvs.DeleteLock(s.prefixKey(name)) if err != nil { @@ -92,14 +93,14 @@ func (s *clusterStorage) Unlock(ctx context.Context, name string) error { } // Store puts value at key. -func (s *clusterStorage) Store(ctx context.Context, key string, value []byte) error { +func (s *storage) Store(ctx context.Context, key string, value []byte) error { s.logger.Debug().WithField("key", key).Log("StorageStore") encodedValue := base64.StdEncoding.EncodeToString(value) return s.kvs.SetKV(s.prefixKey(key), encodedValue) } // Load retrieves the value at key. -func (s *clusterStorage) Load(ctx context.Context, key string) ([]byte, error) { +func (s *storage) Load(ctx context.Context, key string) ([]byte, error) { s.logger.Debug().WithField("key", key).Log("StorageLoad") encodedValue, _, err := s.kvs.GetKV(s.prefixKey(key)) if err != nil { @@ -113,14 +114,14 @@ func (s *clusterStorage) Load(ctx context.Context, key string) ([]byte, error) { // Delete deletes key. An error should be // returned only if the key still exists // when the method returns. -func (s *clusterStorage) Delete(ctx context.Context, key string) error { +func (s *storage) Delete(ctx context.Context, key string) error { s.logger.Debug().WithField("key", key).Log("StorageDelete") return s.kvs.UnsetKV(s.prefixKey(key)) } // Exists returns true if the key exists // and there was no error checking. -func (s *clusterStorage) Exists(ctx context.Context, key string) bool { +func (s *storage) Exists(ctx context.Context, key string) bool { s.logger.Debug().WithField("key", key).Log("StorageExits") _, _, err := s.kvs.GetKV(s.prefixKey(key)) return err == nil @@ -131,7 +132,7 @@ func (s *clusterStorage) Exists(ctx context.Context, key string) bool { // will be enumerated (i.e. "directories" // should be walked); otherwise, only keys // prefixed exactly by prefix will be listed. -func (s *clusterStorage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { +func (s *storage) List(ctx context.Context, prefix string, recursive bool) ([]string, error) { s.logger.Debug().WithField("prefix", prefix).Log("StorageList") values := s.kvs.ListKV(s.prefixKey(prefix)) @@ -166,7 +167,7 @@ func (s *clusterStorage) List(ctx context.Context, prefix string, recursive bool } // Stat returns information about key. -func (s *clusterStorage) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) { +func (s *storage) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) { s.logger.Debug().WithField("key", key).Log("StorageStat") encodedValue, lastModified, err := s.kvs.GetKV(s.prefixKey(key)) if err != nil { diff --git a/cluster/tls_test.go b/cluster/autocert/storage_test.go similarity index 97% rename from cluster/tls_test.go rename to cluster/autocert/storage_test.go index 625d0cbf..9736c784 100644 --- a/cluster/tls_test.go +++ b/cluster/autocert/storage_test.go @@ -1,4 +1,4 @@ -package cluster +package autocert import ( "context" @@ -8,16 +8,17 @@ import ( "time" "github.com/caddyserver/certmagic" + "github.com/datarhei/core/v16/cluster/kvs" "github.com/stretchr/testify/require" ) func setupStorage() (certmagic.Storage, error) { - kvs, err := NewMemoryKVS() + kvs, err := kvs.NewMemoryKVS() if err != nil { return nil, err } - return NewClusterStorage(kvs, "some_prefix", nil) + return NewStorage(kvs, "some_prefix", nil) } func TestStorageStore(t *testing.T) { diff --git a/cluster/cluster.go b/cluster/cluster.go index 8f6f2bb5..1f04aec0 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -14,10 +14,12 @@ import ( "time" "github.com/datarhei/core/v16/autocert" + clusterautocert "github.com/datarhei/core/v16/cluster/autocert" apiclient "github.com/datarhei/core/v16/cluster/client" "github.com/datarhei/core/v16/cluster/forwarder" clusteriam "github.com/datarhei/core/v16/cluster/iam" clusteriamadapter "github.com/datarhei/core/v16/cluster/iam/adapter" + "github.com/datarhei/core/v16/cluster/kvs" "github.com/datarhei/core/v16/cluster/proxy" "github.com/datarhei/core/v16/cluster/raft" "github.com/datarhei/core/v16/cluster/store" @@ -72,7 +74,7 @@ type Cluster interface { SetPolicies(origin, name string, policies []iamaccess.Policy) error RemoveIdentity(origin string, name string) error - CreateLock(origin string, name string, validUntil time.Time) (*Lock, error) + CreateLock(origin string, name string, validUntil time.Time) (*kvs.Lock, error) DeleteLock(origin string, name string) error ListLocks() map[string]time.Time @@ -434,11 +436,15 @@ func (c *cluster) setup(ctx context.Context) error { return fmt.Errorf("tls: cluster KVS: %w", err) } - storage, err := NewClusterStorage(kvs, "core-cluster-certificates", c.logger.WithComponent("KVS")) + storage, err := clusterautocert.NewStorage(kvs, "core-cluster-certificates", c.logger.WithComponent("KVS")) if err != nil { return fmt.Errorf("tls: certificate store: %w", err) } + if len(c.config.TLS.Secret) != 0 { + storage = autocert.NewCryptoStorage(storage, autocert.NewCrypto(c.config.TLS.Secret)) + } + manager, err := autocert.New(autocert.Config{ Storage: storage, DefaultHostname: hostnames[0], @@ -1164,6 +1170,10 @@ func verifyClusterConfig(local, remote *config.Config) error { if local.TLS.Staging != remote.TLS.Staging { return fmt.Errorf("tls.staging is different") } + + if local.TLS.Secret != remote.TLS.Secret { + return fmt.Errorf("tls.secret is different") + } } } @@ -1460,7 +1470,7 @@ func (c *cluster) RemoveIdentity(origin string, name string) error { return c.applyCommand(cmd) } -func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) (*Lock, error) { +func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) (*kvs.Lock, error) { if ok, _ := c.IsClusterDegraded(); ok { return nil, ErrDegraded } @@ -1471,7 +1481,7 @@ func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) ( return nil, err } - l := &Lock{ + l := &kvs.Lock{ ValidUntil: validUntil, } @@ -1491,7 +1501,7 @@ func (c *cluster) CreateLock(origin string, name string, validUntil time.Time) ( return nil, err } - l := &Lock{ + l := &kvs.Lock{ ValidUntil: validUntil, } diff --git a/cluster/kvs.go b/cluster/kvs.go index 22f03ce0..1980e641 100644 --- a/cluster/kvs.go +++ b/cluster/kvs.go @@ -1,72 +1,19 @@ package cluster import ( - "context" - "fmt" - "io/fs" - "strings" - "sync" "time" + "github.com/datarhei/core/v16/cluster/kvs" "github.com/datarhei/core/v16/cluster/store" "github.com/datarhei/core/v16/log" ) -type KVS interface { - CreateLock(name string, validUntil time.Time) (*Lock, error) - DeleteLock(name string) error - ListLocks() map[string]time.Time - - SetKV(key, value string) error - UnsetKV(key string) error - GetKV(key string) (string, time.Time, error) - ListKV(prefix string) map[string]store.Value -} - -type Lock struct { - ValidUntil time.Time - ctx context.Context - cancel context.CancelFunc - lock sync.Mutex -} - -func (l *Lock) Expired() <-chan struct{} { - l.lock.Lock() - defer l.lock.Unlock() - - if l.ctx == nil { - l.ctx, l.cancel = context.WithDeadline(context.Background(), l.ValidUntil) - - go func(l *Lock) { - <-l.ctx.Done() - - l.lock.Lock() - defer l.lock.Unlock() - if l.cancel != nil { - l.cancel() - } - }(l) - } - - return l.ctx.Done() -} - -func (l *Lock) Unlock() { - l.lock.Lock() - defer l.lock.Unlock() - - if l.cancel != nil { - l.ValidUntil = time.Now() - l.cancel() - } -} - type clusterKVS struct { cluster Cluster logger log.Logger } -func NewClusterKVS(cluster Cluster, logger log.Logger) (KVS, error) { +func NewClusterKVS(cluster Cluster, logger log.Logger) (kvs.KVS, error) { s := &clusterKVS{ cluster: cluster, logger: logger, @@ -79,7 +26,7 @@ func NewClusterKVS(cluster Cluster, logger log.Logger) (KVS, error) { return s, nil } -func (s *clusterKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) { +func (s *clusterKVS) CreateLock(name string, validUntil time.Time) (*kvs.Lock, error) { s.logger.Debug().WithFields(log.Fields{ "name": name, "valid_until": validUntil, @@ -119,123 +66,3 @@ func (s *clusterKVS) ListKV(prefix string) map[string]store.Value { s.logger.Debug().Log("List KV") return s.cluster.ListKV(prefix) } - -type memKVS struct { - lock sync.Mutex - - locks map[string]*Lock - values map[string]store.Value -} - -func NewMemoryKVS() (KVS, error) { - return &memKVS{ - locks: map[string]*Lock{}, - values: map[string]store.Value{}, - }, nil -} - -func (s *memKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) { - s.lock.Lock() - defer s.lock.Unlock() - - l, ok := s.locks[name] - - if ok { - if time.Now().Before(l.ValidUntil) { - return nil, fmt.Errorf("the lock with the name '%s' already exists", name) - } - } - - l = &Lock{ - ValidUntil: validUntil, - } - - s.locks[name] = l - - return l, nil -} - -func (s *memKVS) DeleteLock(name string) error { - s.lock.Lock() - defer s.lock.Unlock() - - lock, ok := s.locks[name] - if !ok { - return fmt.Errorf("the lock with the name '%s' doesn't exist", name) - } - - lock.Unlock() - - delete(s.locks, name) - - return nil -} - -func (s *memKVS) ListLocks() map[string]time.Time { - s.lock.Lock() - defer s.lock.Unlock() - - m := map[string]time.Time{} - - for key, lock := range s.locks { - m[key] = lock.ValidUntil - } - - return m -} - -func (s *memKVS) SetKV(key, value string) error { - s.lock.Lock() - defer s.lock.Unlock() - - v := s.values[key] - - v.Value = value - v.UpdatedAt = time.Now() - - s.values[key] = v - - return nil -} - -func (s *memKVS) UnsetKV(key string) error { - s.lock.Lock() - defer s.lock.Unlock() - - if _, ok := s.values[key]; !ok { - return fs.ErrNotExist - } - - delete(s.values, key) - - return nil -} - -func (s *memKVS) GetKV(key string) (string, time.Time, error) { - s.lock.Lock() - defer s.lock.Unlock() - - v, ok := s.values[key] - if !ok { - return "", time.Time{}, fs.ErrNotExist - } - - return v.Value, v.UpdatedAt, nil -} - -func (s *memKVS) ListKV(prefix string) map[string]store.Value { - s.lock.Lock() - defer s.lock.Unlock() - - m := map[string]store.Value{} - - for key, value := range s.values { - if !strings.HasPrefix(key, prefix) { - continue - } - - m[key] = value - } - - return m -} diff --git a/cluster/kvs/kvs.go b/cluster/kvs/kvs.go new file mode 100644 index 00000000..d068c920 --- /dev/null +++ b/cluster/kvs/kvs.go @@ -0,0 +1,58 @@ +package kvs + +import ( + "context" + "sync" + "time" + + "github.com/datarhei/core/v16/cluster/store" +) + +type KVS interface { + CreateLock(name string, validUntil time.Time) (*Lock, error) + DeleteLock(name string) error + ListLocks() map[string]time.Time + + SetKV(key, value string) error + UnsetKV(key string) error + GetKV(key string) (string, time.Time, error) + ListKV(prefix string) map[string]store.Value +} + +type Lock struct { + ValidUntil time.Time + ctx context.Context + cancel context.CancelFunc + lock sync.Mutex +} + +func (l *Lock) Expired() <-chan struct{} { + l.lock.Lock() + defer l.lock.Unlock() + + if l.ctx == nil { + l.ctx, l.cancel = context.WithDeadline(context.Background(), l.ValidUntil) + + go func(l *Lock) { + <-l.ctx.Done() + + l.lock.Lock() + defer l.lock.Unlock() + if l.cancel != nil { + l.cancel() + } + }(l) + } + + return l.ctx.Done() +} + +func (l *Lock) Unlock() { + l.lock.Lock() + defer l.lock.Unlock() + + if l.cancel != nil { + l.ValidUntil = time.Now() + l.cancel() + } +} diff --git a/cluster/kvs/memory.go b/cluster/kvs/memory.go new file mode 100644 index 00000000..b04b5fd4 --- /dev/null +++ b/cluster/kvs/memory.go @@ -0,0 +1,131 @@ +package kvs + +import ( + "fmt" + "io/fs" + "strings" + "sync" + "time" + + "github.com/datarhei/core/v16/cluster/store" +) + +type memKVS struct { + lock sync.Mutex + + locks map[string]*Lock + values map[string]store.Value +} + +func NewMemoryKVS() (KVS, error) { + return &memKVS{ + locks: map[string]*Lock{}, + values: map[string]store.Value{}, + }, nil +} + +func (s *memKVS) CreateLock(name string, validUntil time.Time) (*Lock, error) { + s.lock.Lock() + defer s.lock.Unlock() + + l, ok := s.locks[name] + + if ok { + if time.Now().Before(l.ValidUntil) { + return nil, fmt.Errorf("the lock with the name '%s' already exists", name) + } + } + + l = &Lock{ + ValidUntil: validUntil, + } + + s.locks[name] = l + + return l, nil +} + +func (s *memKVS) DeleteLock(name string) error { + s.lock.Lock() + defer s.lock.Unlock() + + lock, ok := s.locks[name] + if !ok { + return fmt.Errorf("the lock with the name '%s' doesn't exist", name) + } + + lock.Unlock() + + delete(s.locks, name) + + return nil +} + +func (s *memKVS) ListLocks() map[string]time.Time { + s.lock.Lock() + defer s.lock.Unlock() + + m := map[string]time.Time{} + + for key, lock := range s.locks { + m[key] = lock.ValidUntil + } + + return m +} + +func (s *memKVS) SetKV(key, value string) error { + s.lock.Lock() + defer s.lock.Unlock() + + v := s.values[key] + + v.Value = value + v.UpdatedAt = time.Now() + + s.values[key] = v + + return nil +} + +func (s *memKVS) UnsetKV(key string) error { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.values[key]; !ok { + return fs.ErrNotExist + } + + delete(s.values, key) + + return nil +} + +func (s *memKVS) GetKV(key string) (string, time.Time, error) { + s.lock.Lock() + defer s.lock.Unlock() + + v, ok := s.values[key] + if !ok { + return "", time.Time{}, fs.ErrNotExist + } + + return v.Value, v.UpdatedAt, nil +} + +func (s *memKVS) ListKV(prefix string) map[string]store.Value { + s.lock.Lock() + defer s.lock.Unlock() + + m := map[string]store.Value{} + + for key, value := range s.values { + if !strings.HasPrefix(key, prefix) { + continue + } + + m[key] = value + } + + return m +} diff --git a/cluster/kvs_test.go b/cluster/kvs/memory_test.go similarity index 99% rename from cluster/kvs_test.go rename to cluster/kvs/memory_test.go index 900cb38f..b2e49f46 100644 --- a/cluster/kvs_test.go +++ b/cluster/kvs/memory_test.go @@ -1,4 +1,4 @@ -package cluster +package kvs import ( "io/fs" diff --git a/config/config.go b/config/config.go index c2a48aa4..bf08c25f 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,7 @@ import ( "github.com/datarhei/core/v16/config/vars" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/math/rand" + "github.com/datarhei/core/v16/slices" haikunator "github.com/atrox/haikunatorgo/v2" "github.com/google/uuid" @@ -108,35 +109,35 @@ func (d *Config) Clone() *Config { data.Resources = d.Resources data.Cluster = d.Cluster - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) - data.Storage.Disk.Cache.Types.Allow = copy.Slice(d.Storage.Disk.Cache.Types.Allow) - data.Storage.Disk.Cache.Types.Block = copy.Slice(d.Storage.Disk.Cache.Types.Block) - data.Storage.S3 = copy.Slice(d.Storage.S3) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) + data.Storage.Disk.Cache.Types.Allow = slices.Copy(d.Storage.Disk.Cache.Types.Allow) + data.Storage.Disk.Cache.Types.Block = slices.Copy(d.Storage.Disk.Cache.Types.Block) + data.Storage.S3 = slices.Copy(d.Storage.S3) - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) - data.Cluster.Peers = copy.Slice(d.Cluster.Peers) + data.Cluster.Peers = slices.Copy(d.Cluster.Peers) data.vars.Transfer(&d.vars) @@ -187,6 +188,7 @@ func (d *Config) init() { d.vars.Register(value.NewBool(&d.TLS.Auto, false), "tls.auto", "CORE_TLS_AUTO", nil, "Enable Let's Encrypt certificate", false, false) d.vars.Register(value.NewEmail(&d.TLS.Email, "cert@datarhei.com"), "tls.email", "CORE_TLS_EMAIL", nil, "Email for Let's Encrypt registration", false, false) d.vars.Register(value.NewBool(&d.TLS.Staging, false), "tls.staging", "CORE_TLS_STAGING", nil, "Use Let's Encrypt staging CA", false, false) + d.vars.Register(value.NewString(&d.TLS.Secret, ""), "tls.secret", "CORE_TLS_SECRET", nil, "Use this secret to encrypt automatic certificates on the storage", false, true) d.vars.Register(value.NewFile(&d.TLS.CertFile, "", d.fs), "tls.cert_file", "CORE_TLS_CERT_FILE", []string{"CORE_TLS_CERTFILE"}, "Path to certificate file in PEM format", false, false) d.vars.Register(value.NewFile(&d.TLS.KeyFile, "", d.fs), "tls.key_file", "CORE_TLS_KEY_FILE", []string{"CORE_TLS_KEYFILE"}, "Path to key file in PEM format", false, false) diff --git a/config/copy/copy.go b/config/copy/copy.go index 2541bf48..ac48db13 100644 --- a/config/copy/copy.go +++ b/config/copy/copy.go @@ -1,6 +1,9 @@ package copy -import "github.com/datarhei/core/v16/config/value" +import ( + "github.com/datarhei/core/v16/config/value" + "github.com/datarhei/core/v16/slices" +) func StringMap(src map[string]string) map[string]string { dst := make(map[string]string) @@ -13,18 +16,11 @@ func StringMap(src map[string]string) map[string]string { } func TenantSlice(src []value.Auth0Tenant) []value.Auth0Tenant { - dst := Slice(src) + dst := slices.Copy(src) for i, t := range src { - dst[i].Users = Slice(t.Users) + dst[i].Users = slices.Copy(t.Users) } return dst } - -func Slice[T any](src []T) []T { - dst := make([]T, len(src)) - copy(dst, src) - - return dst -} diff --git a/config/data.go b/config/data.go index 6b1c490c..7e55af8b 100644 --- a/config/data.go +++ b/config/data.go @@ -7,6 +7,7 @@ import ( v2 "github.com/datarhei/core/v16/config/v2" "github.com/datarhei/core/v16/config/value" "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/slices" ) // Data is the actual configuration data for the app @@ -63,6 +64,7 @@ type Data struct { Auto bool `json:"auto"` Email string `json:"email"` Staging bool `json:"staging"` + Secret string `json:"secret"` CertFile string `json:"cert_file"` KeyFile string `json:"key_file"` } `json:"tls"` @@ -214,45 +216,45 @@ func MergeV2toV3(data *Data, d *v2.Data) (*Data, error) { data.Service = d.Service data.Router = d.Router - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) data.FFmpeg.Binary = d.FFmpeg.Binary data.FFmpeg.MaxProcesses = d.FFmpeg.MaxProcesses - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) data.FFmpeg.Log.MaxLines = d.FFmpeg.Log.MaxLines data.FFmpeg.Log.MaxHistory = d.FFmpeg.Log.MaxHistory data.Sessions.Enable = d.Sessions.Enable - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) data.Sessions.SessionTimeout = d.Sessions.SessionTimeout data.Sessions.Persist = d.Sessions.Persist data.Sessions.PersistInterval = d.Sessions.PersistInterval data.Sessions.MaxBitrate = d.Sessions.MaxBitrate data.Sessions.MaxSessions = d.Sessions.MaxSessions - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) data.Storage.MimeTypes = d.Storage.MimeTypes data.Storage.CORS = d.Storage.CORS - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) data.Storage.Memory = d.Storage.Memory @@ -273,7 +275,7 @@ func MergeV2toV3(data *Data, d *v2.Data) (*Data, error) { data.Storage.Disk.Cache.Size = d.Storage.Disk.Cache.Size data.Storage.Disk.Cache.FileSize = d.Storage.Disk.Cache.FileSize data.Storage.Disk.Cache.TTL = d.Storage.Disk.Cache.TTL - data.Storage.Disk.Cache.Types.Allow = copy.Slice(d.Storage.Disk.Cache.Types) + data.Storage.Disk.Cache.Types.Allow = slices.Copy(d.Storage.Disk.Cache.Types) data.Storage.S3 = []value.S3Storage{} @@ -307,39 +309,39 @@ func DowngradeV3toV2(d *Data) (*v2.Data, error) { data.Service = d.Service data.Router = d.Router - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) data.FFmpeg.Binary = d.FFmpeg.Binary data.FFmpeg.MaxProcesses = d.FFmpeg.MaxProcesses - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) data.FFmpeg.Log.MaxLines = d.FFmpeg.Log.MaxLines data.FFmpeg.Log.MaxHistory = d.FFmpeg.Log.MaxHistory data.Sessions.Enable = d.Sessions.Enable - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) data.Sessions.SessionTimeout = d.Sessions.SessionTimeout data.Sessions.Persist = d.Sessions.Persist data.Sessions.PersistInterval = d.Sessions.PersistInterval data.Sessions.MaxBitrate = d.Sessions.MaxBitrate data.Sessions.MaxSessions = d.Sessions.MaxSessions - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) // Actual changes @@ -355,7 +357,7 @@ func DowngradeV3toV2(d *Data) (*v2.Data, error) { data.Storage.MimeTypes = d.Storage.MimeTypes data.Storage.CORS = d.Storage.CORS - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) data.Storage.Memory = d.Storage.Memory @@ -365,7 +367,7 @@ func DowngradeV3toV2(d *Data) (*v2.Data, error) { data.Storage.Disk.Cache.Size = d.Storage.Disk.Cache.Size data.Storage.Disk.Cache.FileSize = d.Storage.Disk.Cache.FileSize data.Storage.Disk.Cache.TTL = d.Storage.Disk.Cache.TTL - data.Storage.Disk.Cache.Types = copy.Slice(d.Storage.Disk.Cache.Types.Allow) + data.Storage.Disk.Cache.Types = slices.Copy(d.Storage.Disk.Cache.Types.Allow) data.Version = 2 diff --git a/config/v1/config.go b/config/v1/config.go index 022edfe9..b19639b0 100644 --- a/config/v1/config.go +++ b/config/v1/config.go @@ -10,6 +10,7 @@ import ( "github.com/datarhei/core/v16/config/vars" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/math/rand" + "github.com/datarhei/core/v16/slices" haikunator "github.com/atrox/haikunatorgo/v2" "github.com/google/uuid" @@ -81,30 +82,30 @@ func (d *Config) Clone() *Config { data.Service = d.Service data.Router = d.Router - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) - data.Storage.Disk.Cache.Types = copy.Slice(d.Storage.Disk.Cache.Types) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) + data.Storage.Disk.Cache.Types = slices.Copy(d.Storage.Disk.Cache.Types) - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) data.vars.Transfer(&d.vars) diff --git a/config/v2/config.go b/config/v2/config.go index e1bfb0cb..46c6c4b0 100644 --- a/config/v2/config.go +++ b/config/v2/config.go @@ -10,6 +10,7 @@ import ( "github.com/datarhei/core/v16/config/vars" "github.com/datarhei/core/v16/io/fs" "github.com/datarhei/core/v16/math/rand" + "github.com/datarhei/core/v16/slices" haikunator "github.com/atrox/haikunatorgo/v2" "github.com/google/uuid" @@ -81,30 +82,30 @@ func (d *Config) Clone() *Config { data.Service = d.Service data.Router = d.Router - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) - data.Storage.Disk.Cache.Types = copy.Slice(d.Storage.Disk.Cache.Types) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) + data.Storage.Disk.Cache.Types = slices.Copy(d.Storage.Disk.Cache.Types) - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) data.vars.Transfer(&d.vars) diff --git a/config/v2/data.go b/config/v2/data.go index 1c226376..e6acc85a 100644 --- a/config/v2/data.go +++ b/config/v2/data.go @@ -11,6 +11,7 @@ import ( v1 "github.com/datarhei/core/v16/config/v1" "github.com/datarhei/core/v16/config/value" "github.com/datarhei/core/v16/io/fs" + "github.com/datarhei/core/v16/slices" ) type Data struct { @@ -198,29 +199,29 @@ func MergeV1ToV2(data *Data, d *v1.Data) (*Data, error) { data.Service = d.Service data.Router = d.Router - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) // Actual changes @@ -282,29 +283,29 @@ func DowngradeV2toV1(d *Data) (*v1.Data, error) { data.Service = d.Service data.Router = d.Router - data.Log.Topics = copy.Slice(d.Log.Topics) + data.Log.Topics = slices.Copy(d.Log.Topics) - data.Host.Name = copy.Slice(d.Host.Name) + data.Host.Name = slices.Copy(d.Host.Name) - data.API.Access.HTTP.Allow = copy.Slice(d.API.Access.HTTP.Allow) - data.API.Access.HTTP.Block = copy.Slice(d.API.Access.HTTP.Block) - data.API.Access.HTTPS.Allow = copy.Slice(d.API.Access.HTTPS.Allow) - data.API.Access.HTTPS.Block = copy.Slice(d.API.Access.HTTPS.Block) + data.API.Access.HTTP.Allow = slices.Copy(d.API.Access.HTTP.Allow) + data.API.Access.HTTP.Block = slices.Copy(d.API.Access.HTTP.Block) + data.API.Access.HTTPS.Allow = slices.Copy(d.API.Access.HTTPS.Allow) + data.API.Access.HTTPS.Block = slices.Copy(d.API.Access.HTTPS.Block) data.API.Auth.Auth0.Tenants = copy.TenantSlice(d.API.Auth.Auth0.Tenants) - data.Storage.CORS.Origins = copy.Slice(d.Storage.CORS.Origins) + data.Storage.CORS.Origins = slices.Copy(d.Storage.CORS.Origins) - data.FFmpeg.Access.Input.Allow = copy.Slice(d.FFmpeg.Access.Input.Allow) - data.FFmpeg.Access.Input.Block = copy.Slice(d.FFmpeg.Access.Input.Block) - data.FFmpeg.Access.Output.Allow = copy.Slice(d.FFmpeg.Access.Output.Allow) - data.FFmpeg.Access.Output.Block = copy.Slice(d.FFmpeg.Access.Output.Block) + data.FFmpeg.Access.Input.Allow = slices.Copy(d.FFmpeg.Access.Input.Allow) + data.FFmpeg.Access.Input.Block = slices.Copy(d.FFmpeg.Access.Input.Block) + data.FFmpeg.Access.Output.Allow = slices.Copy(d.FFmpeg.Access.Output.Allow) + data.FFmpeg.Access.Output.Block = slices.Copy(d.FFmpeg.Access.Output.Block) - data.Sessions.IPIgnoreList = copy.Slice(d.Sessions.IPIgnoreList) + data.Sessions.IPIgnoreList = slices.Copy(d.Sessions.IPIgnoreList) - data.SRT.Log.Topics = copy.Slice(d.SRT.Log.Topics) + data.SRT.Log.Topics = slices.Copy(d.SRT.Log.Topics) - data.Router.BlockedPrefixes = copy.Slice(d.Router.BlockedPrefixes) + data.Router.BlockedPrefixes = slices.Copy(d.Router.BlockedPrefixes) data.Router.Routes = copy.StringMap(d.Router.Routes) // Actual changes diff --git a/slices/copy.go b/slices/copy.go new file mode 100644 index 00000000..3a46a771 --- /dev/null +++ b/slices/copy.go @@ -0,0 +1,8 @@ +package slices + +func Copy[T any](src []T) []T { + dst := make([]T, len(src)) + copy(dst, src) + + return dst +} diff --git a/slices/copy_test.go b/slices/copy_test.go new file mode 100644 index 00000000..ebc8a448 --- /dev/null +++ b/slices/copy_test.go @@ -0,0 +1,15 @@ +package slices + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCopy(t *testing.T) { + a := []string{"a", "b", "c"} + + b := Copy(a) + + require.Equal(t, []string{"a", "b", "c"}, b) +} diff --git a/slices/diff.go b/slices/diff.go new file mode 100644 index 00000000..64a499c1 --- /dev/null +++ b/slices/diff.go @@ -0,0 +1,28 @@ +package slices + +// Diff returns a sliceof newly added entries and a slice of removed entries based +// the provided slices. +func Diff[T comparable](next, current []T) ([]T, []T) { + added, removed := []T{}, []T{} + + currentMap := map[T]struct{}{} + + for _, name := range current { + currentMap[name] = struct{}{} + } + + for _, name := range next { + if _, ok := currentMap[name]; ok { + delete(currentMap, name) + continue + } + + added = append(added, name) + } + + for name := range currentMap { + removed = append(removed, name) + } + + return added, removed +} diff --git a/slices/diff_test.go b/slices/diff_test.go new file mode 100644 index 00000000..b041b4b8 --- /dev/null +++ b/slices/diff_test.go @@ -0,0 +1,17 @@ +package slices + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDiff(t *testing.T) { + a := []string{"c", "d", "e", "f"} + b := []string{"a", "b", "c", "d"} + + added, removed := Diff(a, b) + + require.Equal(t, []string{"e", "f"}, added) + require.Equal(t, []string{"a", "b"}, removed) +} diff --git a/vendor/golang.org/x/crypto/scrypt/scrypt.go b/vendor/golang.org/x/crypto/scrypt/scrypt.go new file mode 100644 index 00000000..c971a99f --- /dev/null +++ b/vendor/golang.org/x/crypto/scrypt/scrypt.go @@ -0,0 +1,212 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package scrypt implements the scrypt key derivation function as defined in +// Colin Percival's paper "Stronger Key Derivation via Sequential Memory-Hard +// Functions" (https://www.tarsnap.com/scrypt/scrypt.pdf). +package scrypt // import "golang.org/x/crypto/scrypt" + +import ( + "crypto/sha256" + "encoding/binary" + "errors" + "math/bits" + + "golang.org/x/crypto/pbkdf2" +) + +const maxInt = int(^uint(0) >> 1) + +// blockCopy copies n numbers from src into dst. +func blockCopy(dst, src []uint32, n int) { + copy(dst, src[:n]) +} + +// blockXOR XORs numbers from dst with n numbers from src. +func blockXOR(dst, src []uint32, n int) { + for i, v := range src[:n] { + dst[i] ^= v + } +} + +// salsaXOR applies Salsa20/8 to the XOR of 16 numbers from tmp and in, +// and puts the result into both tmp and out. +func salsaXOR(tmp *[16]uint32, in, out []uint32) { + w0 := tmp[0] ^ in[0] + w1 := tmp[1] ^ in[1] + w2 := tmp[2] ^ in[2] + w3 := tmp[3] ^ in[3] + w4 := tmp[4] ^ in[4] + w5 := tmp[5] ^ in[5] + w6 := tmp[6] ^ in[6] + w7 := tmp[7] ^ in[7] + w8 := tmp[8] ^ in[8] + w9 := tmp[9] ^ in[9] + w10 := tmp[10] ^ in[10] + w11 := tmp[11] ^ in[11] + w12 := tmp[12] ^ in[12] + w13 := tmp[13] ^ in[13] + w14 := tmp[14] ^ in[14] + w15 := tmp[15] ^ in[15] + + x0, x1, x2, x3, x4, x5, x6, x7, x8 := w0, w1, w2, w3, w4, w5, w6, w7, w8 + x9, x10, x11, x12, x13, x14, x15 := w9, w10, w11, w12, w13, w14, w15 + + for i := 0; i < 8; i += 2 { + x4 ^= bits.RotateLeft32(x0+x12, 7) + x8 ^= bits.RotateLeft32(x4+x0, 9) + x12 ^= bits.RotateLeft32(x8+x4, 13) + x0 ^= bits.RotateLeft32(x12+x8, 18) + + x9 ^= bits.RotateLeft32(x5+x1, 7) + x13 ^= bits.RotateLeft32(x9+x5, 9) + x1 ^= bits.RotateLeft32(x13+x9, 13) + x5 ^= bits.RotateLeft32(x1+x13, 18) + + x14 ^= bits.RotateLeft32(x10+x6, 7) + x2 ^= bits.RotateLeft32(x14+x10, 9) + x6 ^= bits.RotateLeft32(x2+x14, 13) + x10 ^= bits.RotateLeft32(x6+x2, 18) + + x3 ^= bits.RotateLeft32(x15+x11, 7) + x7 ^= bits.RotateLeft32(x3+x15, 9) + x11 ^= bits.RotateLeft32(x7+x3, 13) + x15 ^= bits.RotateLeft32(x11+x7, 18) + + x1 ^= bits.RotateLeft32(x0+x3, 7) + x2 ^= bits.RotateLeft32(x1+x0, 9) + x3 ^= bits.RotateLeft32(x2+x1, 13) + x0 ^= bits.RotateLeft32(x3+x2, 18) + + x6 ^= bits.RotateLeft32(x5+x4, 7) + x7 ^= bits.RotateLeft32(x6+x5, 9) + x4 ^= bits.RotateLeft32(x7+x6, 13) + x5 ^= bits.RotateLeft32(x4+x7, 18) + + x11 ^= bits.RotateLeft32(x10+x9, 7) + x8 ^= bits.RotateLeft32(x11+x10, 9) + x9 ^= bits.RotateLeft32(x8+x11, 13) + x10 ^= bits.RotateLeft32(x9+x8, 18) + + x12 ^= bits.RotateLeft32(x15+x14, 7) + x13 ^= bits.RotateLeft32(x12+x15, 9) + x14 ^= bits.RotateLeft32(x13+x12, 13) + x15 ^= bits.RotateLeft32(x14+x13, 18) + } + x0 += w0 + x1 += w1 + x2 += w2 + x3 += w3 + x4 += w4 + x5 += w5 + x6 += w6 + x7 += w7 + x8 += w8 + x9 += w9 + x10 += w10 + x11 += w11 + x12 += w12 + x13 += w13 + x14 += w14 + x15 += w15 + + out[0], tmp[0] = x0, x0 + out[1], tmp[1] = x1, x1 + out[2], tmp[2] = x2, x2 + out[3], tmp[3] = x3, x3 + out[4], tmp[4] = x4, x4 + out[5], tmp[5] = x5, x5 + out[6], tmp[6] = x6, x6 + out[7], tmp[7] = x7, x7 + out[8], tmp[8] = x8, x8 + out[9], tmp[9] = x9, x9 + out[10], tmp[10] = x10, x10 + out[11], tmp[11] = x11, x11 + out[12], tmp[12] = x12, x12 + out[13], tmp[13] = x13, x13 + out[14], tmp[14] = x14, x14 + out[15], tmp[15] = x15, x15 +} + +func blockMix(tmp *[16]uint32, in, out []uint32, r int) { + blockCopy(tmp[:], in[(2*r-1)*16:], 16) + for i := 0; i < 2*r; i += 2 { + salsaXOR(tmp, in[i*16:], out[i*8:]) + salsaXOR(tmp, in[i*16+16:], out[i*8+r*16:]) + } +} + +func integer(b []uint32, r int) uint64 { + j := (2*r - 1) * 16 + return uint64(b[j]) | uint64(b[j+1])<<32 +} + +func smix(b []byte, r, N int, v, xy []uint32) { + var tmp [16]uint32 + R := 32 * r + x := xy + y := xy[R:] + + j := 0 + for i := 0; i < R; i++ { + x[i] = binary.LittleEndian.Uint32(b[j:]) + j += 4 + } + for i := 0; i < N; i += 2 { + blockCopy(v[i*R:], x, R) + blockMix(&tmp, x, y, r) + + blockCopy(v[(i+1)*R:], y, R) + blockMix(&tmp, y, x, r) + } + for i := 0; i < N; i += 2 { + j := int(integer(x, r) & uint64(N-1)) + blockXOR(x, v[j*R:], R) + blockMix(&tmp, x, y, r) + + j = int(integer(y, r) & uint64(N-1)) + blockXOR(y, v[j*R:], R) + blockMix(&tmp, y, x, r) + } + j = 0 + for _, v := range x[:R] { + binary.LittleEndian.PutUint32(b[j:], v) + j += 4 + } +} + +// Key derives a key from the password, salt, and cost parameters, returning +// a byte slice of length keyLen that can be used as cryptographic key. +// +// N is a CPU/memory cost parameter, which must be a power of two greater than 1. +// r and p must satisfy r * p < 2³⁰. If the parameters do not satisfy the +// limits, the function returns a nil byte slice and an error. +// +// For example, you can get a derived key for e.g. AES-256 (which needs a +// 32-byte key) by doing: +// +// dk, err := scrypt.Key([]byte("some password"), salt, 32768, 8, 1, 32) +// +// The recommended parameters for interactive logins as of 2017 are N=32768, r=8 +// and p=1. The parameters N, r, and p should be increased as memory latency and +// CPU parallelism increases; consider setting N to the highest power of 2 you +// can derive within 100 milliseconds. Remember to get a good random salt. +func Key(password, salt []byte, N, r, p, keyLen int) ([]byte, error) { + if N <= 1 || N&(N-1) != 0 { + return nil, errors.New("scrypt: N must be > 1 and a power of 2") + } + if uint64(r)*uint64(p) >= 1<<30 || r > maxInt/128/p || r > maxInt/256 || N > maxInt/128/r { + return nil, errors.New("scrypt: parameters are too large") + } + + xy := make([]uint32, 64*r) + v := make([]uint32, 32*N*r) + b := pbkdf2.Key(password, salt, 1, p*128*r, sha256.New) + + for i := 0; i < p; i++ { + smix(b[i*128*r:], r, N, v, xy) + } + + return pbkdf2.Key(password, b, 1, keyLen, sha256.New), nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt index f8258781..cbe43399 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -439,6 +439,7 @@ golang.org/x/crypto/cryptobyte golang.org/x/crypto/cryptobyte/asn1 golang.org/x/crypto/ocsp golang.org/x/crypto/pbkdf2 +golang.org/x/crypto/scrypt golang.org/x/crypto/sha3 # golang.org/x/mod v0.11.0 ## explicit; go 1.17