mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-11-03 10:31:11 +08:00
Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency * Lock Hooks on Add Hook
This commit is contained in:
24
hooks.go
24
hooks.go
@@ -110,10 +110,10 @@ type HookOptions struct {
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal []Hook // a slice of hooks
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.RWMutex // a mutex
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
@@ -138,16 +138,19 @@ func (h *Hooks) Provides(b ...byte) bool {
|
||||
func (h *Hooks) Add(hook Hook, config any) error {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
if h.internal == nil {
|
||||
h.internal = []Hook{}
|
||||
}
|
||||
|
||||
err := hook.Init(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||
}
|
||||
|
||||
h.internal = append(h.internal, hook)
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
i = []Hook{}
|
||||
}
|
||||
|
||||
i = append(i, hook)
|
||||
h.internal.Store(i)
|
||||
atomic.AddInt64(&h.qty, 1)
|
||||
h.wg.Add(1)
|
||||
|
||||
@@ -156,9 +159,12 @@ func (h *Hooks) Add(hook Hook, config any) error {
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
h.RLock()
|
||||
defer h.RUnlock()
|
||||
return append([]Hook{}, h.internal...)
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
return []Hook{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
|
||||
@@ -27,6 +27,10 @@ type modifiedHookBase struct {
|
||||
|
||||
var errTestHook = errors.New("error")
|
||||
|
||||
func (h *modifiedHookBase) ID() string {
|
||||
return "modified"
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Init(config any) error {
|
||||
if config != nil {
|
||||
return errTestHook
|
||||
@@ -178,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHooksAddAndLen(t *testing.T) {
|
||||
func TestHooksAddLenGetAll(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(1), h.Len())
|
||||
|
||||
err = h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(2), h.Len())
|
||||
|
||||
all := h.GetAll()
|
||||
require.Equal(t, "base", all[0].ID())
|
||||
require.Equal(t, "modified", all[1].ID())
|
||||
}
|
||||
|
||||
func TestHooksAddInitFailure(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user