Use Atomic instead of RWMutex for Hooks concurrency (#148)

* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
This commit is contained in:
JB
2023-01-16 19:49:36 +00:00
committed by GitHub
parent 300152413c
commit 0d79f2d63b
2 changed files with 33 additions and 15 deletions

View File

@@ -110,10 +110,10 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.RWMutex // a mutex
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
@@ -138,16 +138,19 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
@@ -156,9 +159,12 @@ func (h *Hooks) Add(hook Hook, config any) error {
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
h.RLock()
defer h.RUnlock()
return append([]Hook{}, h.internal...)
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.

View File

@@ -27,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
@@ -178,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {