Extend onsusbcribe, onunsubscribe events

This commit is contained in:
mochi
2022-05-04 12:53:04 +01:00
parent 9b5cdb0bcc
commit 27f3c484ad
3 changed files with 48 additions and 24 deletions

View File

@@ -6,13 +6,13 @@ import (
// Events provides callback handlers for different event hooks. // Events provides callback handlers for different event hooks.
type Events struct { type Events struct {
OnProcessMessage // published message receieved before evaluation. OnProcessMessage // published message receieved before evaluation.
OnMessage // published message receieved. OnMessage // published message receieved.
OnError // server error. OnError // server error.
OnConnect // client connected. OnConnect // client connected.
OnDisconnect // client disconnected. OnDisconnect // client disconnected.
OnTopicSubscribe // topic subscription created. OnSubscribe // topic subscription created.
OnTopicUnsubscribe // topic subscription removed. OnUnsubscribe // topic subscription removed.
} }
// Packets is an alias for packets.Packet. // Packets is an alias for packets.Packet.
@@ -69,8 +69,8 @@ type OnDisconnect func(Client, error)
// OnDisconnect are handled by the server. // OnDisconnect are handled by the server.
type OnError func(Client, error) type OnError func(Client, error)
// OnTopicSubscribe is called when a new subscription filter for a client is created. // OnSubscribe is called when a new subscription filter for a client is created.
type OnTopicSubscribe func(filter string, client string, qos byte) type OnSubscribe func(filter string, cl Client, qos byte)
// OnTopicUnsubscribe is called when an existing subscription filter for a client is removed. // OnUnsubscribe is called when an existing subscription filter for a client is removed.
type OnTopicUnsubscribe func(filter string, client string) type OnUnsubscribe func(filter string, cl Client)

View File

@@ -400,8 +400,8 @@ func (s *Server) unsubscribeClient(cl *clients.Client) {
for k := range cl.Subscriptions { for k := range cl.Subscriptions {
delete(cl.Subscriptions, k) delete(cl.Subscriptions, k)
if s.Topics.Unsubscribe(k, cl.ID) { if s.Topics.Unsubscribe(k, cl.ID) {
if s.Events.OnTopicUnsubscribe != nil { if s.Events.OnUnsubscribe != nil {
s.Events.OnTopicUnsubscribe(k, cl.ID) s.Events.OnUnsubscribe(k, cl.Info())
} }
atomic.AddInt64(&s.System.Subscriptions, -1) atomic.AddInt64(&s.System.Subscriptions, -1)
} }
@@ -738,10 +738,10 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
if !cl.AC.ACL(cl.Username, pk.Topics[i], false) { if !cl.AC.ACL(cl.Username, pk.Topics[i], false) {
retCodes[i] = packets.ErrSubAckNetworkError retCodes[i] = packets.ErrSubAckNetworkError
} else { } else {
q := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i]) r := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
if q { if r {
if s.Events.OnTopicSubscribe != nil { if s.Events.OnSubscribe != nil {
s.Events.OnTopicSubscribe(pk.Topics[i], cl.ID, pk.Qoss[i]) s.Events.OnSubscribe(pk.Topics[i], cl.Info(), pk.Qoss[i])
} }
atomic.AddInt64(&s.System.Subscriptions, 1) atomic.AddInt64(&s.System.Subscriptions, 1)
} }
@@ -791,8 +791,8 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
for i := 0; i < len(pk.Topics); i++ { for i := 0; i < len(pk.Topics); i++ {
q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID) q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID)
if q { if q {
if s.Events.OnTopicUnsubscribe != nil { if s.Events.OnUnsubscribe != nil {
s.Events.OnTopicUnsubscribe(pk.Topics[i], cl.ID) s.Events.OnUnsubscribe(pk.Topics[i], cl.Info())
} }
atomic.AddInt64(&s.System.Subscriptions, -1) atomic.AddInt64(&s.System.Subscriptions, -1)
} }
@@ -1014,13 +1014,13 @@ func (s *Server) loadServerInfo(v persistence.ServerInfo) {
func (s *Server) loadSubscriptions(v []persistence.Subscription) { func (s *Server) loadSubscriptions(v []persistence.Subscription) {
for _, sub := range v { for _, sub := range v {
if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) { if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) {
if s.Events.OnTopicSubscribe != nil { if cl, ok := s.Clients.Get(sub.Client); ok {
s.Events.OnTopicSubscribe(sub.Filter, sub.Client, sub.QoS) cl.NoteSubscription(sub.Filter, sub.QoS)
if s.Events.OnSubscribe != nil {
s.Events.OnSubscribe(sub.Filter, cl.Info(), sub.QoS)
}
} }
} }
if cl, ok := s.Clients.Get(sub.Client); ok {
cl.NoteSubscription(sub.Filter, sub.QoS)
}
} }
} }

View File

@@ -1946,6 +1946,15 @@ func TestServerProcessSubscribeInvalid(t *testing.T) {
func TestServerProcessSubscribe(t *testing.T) { func TestServerProcessSubscribe(t *testing.T) {
s, cl, r, w := setupClient() s, cl, r, w := setupClient()
subscribeEvent := ""
subscribeClient := ""
s.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
if filter == "a/b/c" {
subscribeEvent = "a/b/c"
subscribeClient = cl.ID
}
}
s.Topics.RetainMessage(packets.Packet{ s.Topics.RetainMessage(packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Publish, Type: packets.Publish,
@@ -1995,6 +2004,8 @@ func TestServerProcessSubscribe(t *testing.T) {
require.Equal(t, byte(1), cl.Subscriptions["d/e/f"]) require.Equal(t, byte(1), cl.Subscriptions["d/e/f"])
require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c")) require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c"))
require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f")) require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f"))
require.Equal(t, "a/b/c", subscribeEvent)
require.Equal(t, cl.ID, subscribeClient)
} }
func TestServerProcessSubscribeFailACL(t *testing.T) { func TestServerProcessSubscribeFailACL(t *testing.T) {
@@ -2114,6 +2125,16 @@ func TestServerProcessUnsubscribeInvalid(t *testing.T) {
func TestServerProcessUnsubscribe(t *testing.T) { func TestServerProcessUnsubscribe(t *testing.T) {
s, cl, r, w := setupClient() s, cl, r, w := setupClient()
unsubscribeEvent := ""
unsubscribeClient := ""
s.Events.OnUnsubscribe = func(filter string, cl events.Client) {
if filter == "a/b/c" {
unsubscribeEvent = "a/b/c"
unsubscribeClient = cl.ID
}
}
s.Clients.Add(cl) s.Clients.Add(cl)
s.Topics.Subscribe("a/b/c", cl.ID, 0) s.Topics.Subscribe("a/b/c", cl.ID, 0)
s.Topics.Subscribe("d/e/f", cl.ID, 1) s.Topics.Subscribe("d/e/f", cl.ID, 1)
@@ -2155,6 +2176,9 @@ func TestServerProcessUnsubscribe(t *testing.T) {
require.NotEmpty(t, s.Topics.Subscribers("a/b/+")) require.NotEmpty(t, s.Topics.Subscribers("a/b/+"))
require.Contains(t, cl.Subscriptions, "a/b/+") require.Contains(t, cl.Subscriptions, "a/b/+")
require.Equal(t, "a/b/c", unsubscribeEvent)
require.Equal(t, cl.ID, unsubscribeClient)
} }
func TestServerProcessUnsubscribeWriteError(t *testing.T) { func TestServerProcessUnsubscribeWriteError(t *testing.T) {