Use Atomic instead of RWMutex for Hooks concurrency (#148)

* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
This commit is contained in:
JB
2023-01-16 19:49:36 +00:00
committed by GitHub
parent 300152413c
commit 0d79f2d63b
2 changed files with 33 additions and 15 deletions

View File

@@ -109,11 +109,11 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence. // Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct { type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server) Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use qty int64 // the number of hooks in use
sync.RWMutex // a mutex sync.Mutex // a mutex for locking when adding hooks
} }
// Len returns the number of hooks added. // Len returns the number of hooks added.
@@ -138,16 +138,19 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error { func (h *Hooks) Add(hook Hook, config any) error {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config) err := hook.Init(config)
if err != nil { if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err) return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
} }
h.internal = append(h.internal, hook) i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1) atomic.AddInt64(&h.qty, 1)
h.wg.Add(1) h.wg.Add(1)
@@ -156,9 +159,12 @@ func (h *Hooks) Add(hook Hook, config any) error {
// GetAll returns a slice of all the hooks. // GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook { func (h *Hooks) GetAll() []Hook {
h.RLock() i, ok := h.internal.Load().([]Hook)
defer h.RUnlock() if !ok {
return append([]Hook{}, h.internal...) return []Hook{}
}
return i
} }
// Stop indicates all attached hooks to gracefully end. // Stop indicates all attached hooks to gracefully end.

View File

@@ -27,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error") var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error { func (h *modifiedHookBase) Init(config any) error {
if config != nil { if config != nil {
return errTestHook return errTestHook
@@ -178,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect)) require.False(t, h.Provides(OnDisconnect))
} }
func TestHooksAddAndLen(t *testing.T) { func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks) h := new(Hooks)
err := h.Add(new(HookBase), nil) err := h.Add(new(HookBase), nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len()) err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
} }
func TestHooksAddInitFailure(t *testing.T) { func TestHooksAddInitFailure(t *testing.T) {