mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-11-03 02:23:49 +08:00
Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency * Lock Hooks on Add Hook
This commit is contained in:
30
hooks.go
30
hooks.go
@@ -109,11 +109,11 @@ type HookOptions struct {
|
|||||||
|
|
||||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||||
type Hooks struct {
|
type Hooks struct {
|
||||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||||
internal []Hook // a slice of hooks
|
internal atomic.Value // a slice of []Hook
|
||||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||||
qty int64 // the number of hooks in use
|
qty int64 // the number of hooks in use
|
||||||
sync.RWMutex // a mutex
|
sync.Mutex // a mutex for locking when adding hooks
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the number of hooks added.
|
// Len returns the number of hooks added.
|
||||||
@@ -138,16 +138,19 @@ func (h *Hooks) Provides(b ...byte) bool {
|
|||||||
func (h *Hooks) Add(hook Hook, config any) error {
|
func (h *Hooks) Add(hook Hook, config any) error {
|
||||||
h.Lock()
|
h.Lock()
|
||||||
defer h.Unlock()
|
defer h.Unlock()
|
||||||
if h.internal == nil {
|
|
||||||
h.internal = []Hook{}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := hook.Init(config)
|
err := hook.Init(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.internal = append(h.internal, hook)
|
i, ok := h.internal.Load().([]Hook)
|
||||||
|
if !ok {
|
||||||
|
i = []Hook{}
|
||||||
|
}
|
||||||
|
|
||||||
|
i = append(i, hook)
|
||||||
|
h.internal.Store(i)
|
||||||
atomic.AddInt64(&h.qty, 1)
|
atomic.AddInt64(&h.qty, 1)
|
||||||
h.wg.Add(1)
|
h.wg.Add(1)
|
||||||
|
|
||||||
@@ -156,9 +159,12 @@ func (h *Hooks) Add(hook Hook, config any) error {
|
|||||||
|
|
||||||
// GetAll returns a slice of all the hooks.
|
// GetAll returns a slice of all the hooks.
|
||||||
func (h *Hooks) GetAll() []Hook {
|
func (h *Hooks) GetAll() []Hook {
|
||||||
h.RLock()
|
i, ok := h.internal.Load().([]Hook)
|
||||||
defer h.RUnlock()
|
if !ok {
|
||||||
return append([]Hook{}, h.internal...)
|
return []Hook{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop indicates all attached hooks to gracefully end.
|
// Stop indicates all attached hooks to gracefully end.
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ type modifiedHookBase struct {
|
|||||||
|
|
||||||
var errTestHook = errors.New("error")
|
var errTestHook = errors.New("error")
|
||||||
|
|
||||||
|
func (h *modifiedHookBase) ID() string {
|
||||||
|
return "modified"
|
||||||
|
}
|
||||||
|
|
||||||
func (h *modifiedHookBase) Init(config any) error {
|
func (h *modifiedHookBase) Init(config any) error {
|
||||||
if config != nil {
|
if config != nil {
|
||||||
return errTestHook
|
return errTestHook
|
||||||
@@ -178,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
|
|||||||
require.False(t, h.Provides(OnDisconnect))
|
require.False(t, h.Provides(OnDisconnect))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHooksAddAndLen(t *testing.T) {
|
func TestHooksAddLenGetAll(t *testing.T) {
|
||||||
h := new(Hooks)
|
h := new(Hooks)
|
||||||
err := h.Add(new(HookBase), nil)
|
err := h.Add(new(HookBase), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
|
||||||
require.Equal(t, int64(1), h.Len())
|
err = h.Add(new(modifiedHookBase), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||||
|
require.Equal(t, int64(2), h.Len())
|
||||||
|
|
||||||
|
all := h.GetAll()
|
||||||
|
require.Equal(t, "base", all[0].ID())
|
||||||
|
require.Equal(t, "modified", all[1].ID())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHooksAddInitFailure(t *testing.T) {
|
func TestHooksAddInitFailure(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user