diff --git a/server/events/events.go b/server/events/events.go index b7394df..229a130 100644 --- a/server/events/events.go +++ b/server/events/events.go @@ -6,13 +6,13 @@ import ( // Events provides callback handlers for different event hooks. type Events struct { - OnProcessMessage // published message receieved before evaluation. - OnMessage // published message receieved. - OnError // server error. - OnConnect // client connected. - OnDisconnect // client disconnected. - OnTopicSubscribe // topic subscription created. - OnTopicUnsubscribe // topic subscription removed. + OnProcessMessage // published message receieved before evaluation. + OnMessage // published message receieved. + OnError // server error. + OnConnect // client connected. + OnDisconnect // client disconnected. + OnSubscribe // topic subscription created. + OnUnsubscribe // topic subscription removed. } // Packets is an alias for packets.Packet. @@ -69,8 +69,8 @@ type OnDisconnect func(Client, error) // OnDisconnect are handled by the server. type OnError func(Client, error) -// OnTopicSubscribe is called when a new subscription filter for a client is created. -type OnTopicSubscribe func(filter string, client string, qos byte) +// OnSubscribe is called when a new subscription filter for a client is created. +type OnSubscribe func(filter string, cl Client, qos byte) -// OnTopicUnsubscribe is called when an existing subscription filter for a client is removed. -type OnTopicUnsubscribe func(filter string, client string) +// OnUnsubscribe is called when an existing subscription filter for a client is removed. +type OnUnsubscribe func(filter string, cl Client) diff --git a/server/server.go b/server/server.go index 41c1088..82c1f65 100644 --- a/server/server.go +++ b/server/server.go @@ -400,8 +400,8 @@ func (s *Server) unsubscribeClient(cl *clients.Client) { for k := range cl.Subscriptions { delete(cl.Subscriptions, k) if s.Topics.Unsubscribe(k, cl.ID) { - if s.Events.OnTopicUnsubscribe != nil { - s.Events.OnTopicUnsubscribe(k, cl.ID) + if s.Events.OnUnsubscribe != nil { + s.Events.OnUnsubscribe(k, cl.Info()) } atomic.AddInt64(&s.System.Subscriptions, -1) } @@ -738,10 +738,10 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error { if !cl.AC.ACL(cl.Username, pk.Topics[i], false) { retCodes[i] = packets.ErrSubAckNetworkError } else { - q := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i]) - if q { - if s.Events.OnTopicSubscribe != nil { - s.Events.OnTopicSubscribe(pk.Topics[i], cl.ID, pk.Qoss[i]) + r := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i]) + if r { + if s.Events.OnSubscribe != nil { + s.Events.OnSubscribe(pk.Topics[i], cl.Info(), pk.Qoss[i]) } atomic.AddInt64(&s.System.Subscriptions, 1) } @@ -791,8 +791,8 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error for i := 0; i < len(pk.Topics); i++ { q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID) if q { - if s.Events.OnTopicUnsubscribe != nil { - s.Events.OnTopicUnsubscribe(pk.Topics[i], cl.ID) + if s.Events.OnUnsubscribe != nil { + s.Events.OnUnsubscribe(pk.Topics[i], cl.Info()) } atomic.AddInt64(&s.System.Subscriptions, -1) } @@ -1014,13 +1014,13 @@ func (s *Server) loadServerInfo(v persistence.ServerInfo) { func (s *Server) loadSubscriptions(v []persistence.Subscription) { for _, sub := range v { if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) { - if s.Events.OnTopicSubscribe != nil { - s.Events.OnTopicSubscribe(sub.Filter, sub.Client, sub.QoS) + if cl, ok := s.Clients.Get(sub.Client); ok { + cl.NoteSubscription(sub.Filter, sub.QoS) + if s.Events.OnSubscribe != nil { + s.Events.OnSubscribe(sub.Filter, cl.Info(), sub.QoS) + } } } - if cl, ok := s.Clients.Get(sub.Client); ok { - cl.NoteSubscription(sub.Filter, sub.QoS) - } } } diff --git a/server/server_test.go b/server/server_test.go index c153de2..4827add 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1946,6 +1946,15 @@ func TestServerProcessSubscribeInvalid(t *testing.T) { func TestServerProcessSubscribe(t *testing.T) { s, cl, r, w := setupClient() + subscribeEvent := "" + subscribeClient := "" + s.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) { + if filter == "a/b/c" { + subscribeEvent = "a/b/c" + subscribeClient = cl.ID + } + } + s.Topics.RetainMessage(packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, @@ -1995,6 +2004,8 @@ func TestServerProcessSubscribe(t *testing.T) { require.Equal(t, byte(1), cl.Subscriptions["d/e/f"]) require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c")) require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f")) + require.Equal(t, "a/b/c", subscribeEvent) + require.Equal(t, cl.ID, subscribeClient) } func TestServerProcessSubscribeFailACL(t *testing.T) { @@ -2114,6 +2125,16 @@ func TestServerProcessUnsubscribeInvalid(t *testing.T) { func TestServerProcessUnsubscribe(t *testing.T) { s, cl, r, w := setupClient() + + unsubscribeEvent := "" + unsubscribeClient := "" + s.Events.OnUnsubscribe = func(filter string, cl events.Client) { + if filter == "a/b/c" { + unsubscribeEvent = "a/b/c" + unsubscribeClient = cl.ID + } + } + s.Clients.Add(cl) s.Topics.Subscribe("a/b/c", cl.ID, 0) s.Topics.Subscribe("d/e/f", cl.ID, 1) @@ -2155,6 +2176,9 @@ func TestServerProcessUnsubscribe(t *testing.T) { require.NotEmpty(t, s.Topics.Subscribers("a/b/+")) require.Contains(t, cl.Subscriptions, "a/b/+") + + require.Equal(t, "a/b/c", unsubscribeEvent) + require.Equal(t, cl.ID, unsubscribeClient) } func TestServerProcessUnsubscribeWriteError(t *testing.T) {