mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-26 20:21:12 +08:00
Another code implementation for Inline Client Subscriptions. (#284)
* Another code implementation for Inline Client Subscriptions. * Added a few test cases. * Changed the return value types of Server.Unsubscribe() and Subscribe() to boolean. * Implementing the delivery of retained messages and supporting multiple callbacks per topic using different inline client IDs. * Added validation checks for the legality of the inline client id during Subscribe and Unsubscribe. * Added validation checks for the legality of the client during Subscribe and Unsubscribe. * Fixed the TestServerSubscribe/invalid_client_id test case failure. * Add Server.inlineClient and Temporarily removing test cases for better code review readability. * Using server.inlineClient in server.InjectPacket(). * After unsubscribing, if there are other subscriptions in particle.inlineSubscriptions, particle cannot be deleted. * Add comments to particle.inlineSubscriptions and modify to return ErrTopicFilterInvalid when the topic is invalid during subscription. * Fixed some test case failures caused by adding inlineClient to the server. * More test cases have been added. * Optimization of test case code. * Modify server.go: When used as a publisher, treat the qos of inline client-published messages as 0. * Resolve conflict.
This commit is contained in:
95
examples/direct/main.go
Normal file
95
examples/direct/main.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
|
||||||
|
// SPDX-FileContributor: mochi-co
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mochi-mqtt/server/v2/hooks/auth"
|
||||||
|
|
||||||
|
mqtt "github.com/mochi-mqtt/server/v2"
|
||||||
|
"github.com/mochi-mqtt/server/v2/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
sigs := make(chan os.Signal, 1)
|
||||||
|
done := make(chan bool, 1)
|
||||||
|
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
<-sigs
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
server := mqtt.New(nil)
|
||||||
|
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
go func() {
|
||||||
|
err := server.Serve()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Demonstration of using an inline client to directly subscribe to a topic and receive a message when
|
||||||
|
// that subscription is activated. The inline subscription method uses the same internal subscription logic
|
||||||
|
// as used for external (normal) clients.
|
||||||
|
go func() {
|
||||||
|
// Inline subscriptions can also receive retained messages on subscription.
|
||||||
|
_ = server.Publish("direct/retained", []byte("retained message"), true, 0)
|
||||||
|
_ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0)
|
||||||
|
|
||||||
|
// Subscribe to a filter and handle any received messages via a callback function.
|
||||||
|
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
|
||||||
|
}
|
||||||
|
server.Log.Info("inline client subscribing")
|
||||||
|
server.Subscribe("direct/#", 1, callbackFn)
|
||||||
|
server.Subscribe("direct/#", 2, callbackFn)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// There is a shorthand convenience function, Publish, for easily sending
|
||||||
|
// publish packets if you are not concerned with creating your own packets.
|
||||||
|
go func() {
|
||||||
|
for range time.Tick(time.Second * 3) {
|
||||||
|
err := server.Publish("direct/publish", []byte("scheduled message"), false, 0)
|
||||||
|
if err != nil {
|
||||||
|
server.Log.Error("server.Publish", "error", err)
|
||||||
|
}
|
||||||
|
server.Log.Info("main.go issued direct message to direct/publish")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
|
// Unsubscribe from the same filter to stop receiving messages.
|
||||||
|
server.Log.Info("inline client unsubscribing")
|
||||||
|
server.Unsubscribe("direct/#", 1)
|
||||||
|
}()
|
||||||
|
// If you want to have more control over your packets, you can directly inject a packet of any kind into the broker.
|
||||||
|
//go func() {
|
||||||
|
// for range time.Tick(time.Second * 5) {
|
||||||
|
// err := server.InjectPacket(cl, packets.Packet{
|
||||||
|
// FixedHeader: packets.FixedHeader{
|
||||||
|
// Type: packets.Publish,
|
||||||
|
// },
|
||||||
|
// TopicName: "direct/publish",
|
||||||
|
// Payload: []byte("injected scheduled message"),
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//}()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
server.Log.Warn("caught signal, stopping...")
|
||||||
|
_ = server.Close()
|
||||||
|
server.Log.Info("main.go finished")
|
||||||
|
}
|
@@ -126,6 +126,7 @@ var (
|
|||||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||||
|
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
|
||||||
|
|
||||||
// MQTTv3 specific bytes.
|
// MQTTv3 specific bytes.
|
||||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||||
|
93
server.go
93
server.go
@@ -28,6 +28,8 @@ import (
|
|||||||
const (
|
const (
|
||||||
Version = "2.3.0" // the current server version.
|
Version = "2.3.0" // the current server version.
|
||||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||||
|
LocalListener = "local"
|
||||||
|
InlineClientId = "inline"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -109,15 +111,16 @@ type Options struct {
|
|||||||
// Server is an MQTT broker server. It should be created with server.New()
|
// Server is an MQTT broker server. It should be created with server.New()
|
||||||
// in order to ensure all the internal fields are correctly populated.
|
// in order to ensure all the internal fields are correctly populated.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Options *Options // configurable server options
|
Options *Options // configurable server options
|
||||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections
|
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections
|
||||||
Clients *Clients // clients known to the broker
|
Clients *Clients // clients known to the broker
|
||||||
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
|
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
|
||||||
Info *system.Info // values about the server commonly known as $SYS topics
|
Info *system.Info // values about the server commonly known as $SYS topics
|
||||||
loop *loop // loop contains tickers for the system event loop
|
loop *loop // loop contains tickers for the system event loop
|
||||||
done chan bool // indicate that the server is ending
|
done chan bool // indicate that the server is ending
|
||||||
Log *slog.Logger // minimal no-alloc logger
|
Log *slog.Logger // minimal no-alloc logger
|
||||||
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage.
|
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage.
|
||||||
|
inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish.
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop contains interval tickers for the system events loop.
|
// loop contains interval tickers for the system events loop.
|
||||||
@@ -170,6 +173,8 @@ func New(opts *Options) *Server {
|
|||||||
Log: opts.Logger,
|
Log: opts.Logger,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true)
|
||||||
|
s.Clients.Add(s.inlineClient)
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@@ -649,8 +654,7 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
|
|||||||
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
|
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
|
||||||
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
|
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
|
||||||
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
|
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
|
||||||
cl := s.NewClient(nil, "local", "inline", true)
|
return s.InjectPacket(s.inlineClient, packets.Packet{
|
||||||
return s.InjectPacket(cl, packets.Packet{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
FixedHeader: packets.FixedHeader{
|
||||||
Type: packets.Publish,
|
Type: packets.Publish,
|
||||||
Qos: qos,
|
Qos: qos,
|
||||||
@@ -662,6 +666,64 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Subscribe adds an inline subscription for the specified topic filter and subscription identifier
|
||||||
|
// with the provided handler function.
|
||||||
|
func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error {
|
||||||
|
if handler == nil {
|
||||||
|
return packets.ErrInlineSubscriptionHandlerInvalid
|
||||||
|
} else if !IsValidFilter(filter, false) {
|
||||||
|
return packets.ErrTopicFilterInvalid
|
||||||
|
}
|
||||||
|
subscription := packets.Subscription{
|
||||||
|
Identifier: subscriptionId,
|
||||||
|
Filter: filter,
|
||||||
|
}
|
||||||
|
|
||||||
|
pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client.
|
||||||
|
Origin: s.inlineClient.ID,
|
||||||
|
FixedHeader: packets.FixedHeader{Type: packets.Subscribe},
|
||||||
|
Filters: packets.Subscriptions{subscription},
|
||||||
|
})
|
||||||
|
|
||||||
|
inlineSubscription := InlineSubscription{
|
||||||
|
Subscription: subscription,
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Topics.InlineSubscribe(inlineSubscription)
|
||||||
|
s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code})
|
||||||
|
|
||||||
|
// Handling retained messages.
|
||||||
|
for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4]
|
||||||
|
handler(s.inlineClient, inlineSubscription.Subscription, pkv)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes an inline subscription for the specified subscription and topic filter.
|
||||||
|
// It allows you to unsubscribe a specific subscription from the internal subscription
|
||||||
|
// associated with the given topic filter.
|
||||||
|
func (s *Server) Unsubscribe(filter string, subscriptionId int) error {
|
||||||
|
if !IsValidFilter(filter, false) {
|
||||||
|
return packets.ErrTopicFilterInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{
|
||||||
|
Origin: s.inlineClient.ID,
|
||||||
|
FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe},
|
||||||
|
Filters: packets.Subscriptions{
|
||||||
|
{
|
||||||
|
Identifier: subscriptionId,
|
||||||
|
Filter: filter,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Topics.InlineUnsubscribe(subscriptionId, filter)
|
||||||
|
s.hooks.OnUnsubscribed(s.inlineClient, pk)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
|
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
|
||||||
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
|
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
|
||||||
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
|
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
|
||||||
@@ -736,7 +798,10 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
|||||||
s.retainMessage(cl, pk)
|
s.retainMessage(cl, pk)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pk.FixedHeader.Qos == 0 {
|
// If it's inlineClient, it can't handle PUBREC and PUBREL.
|
||||||
|
// When it publishes a package with a qos > 0, the server treats
|
||||||
|
// the package as qos=0, and the client receives it as qos=1 or 2.
|
||||||
|
if pk.FixedHeader.Qos == 0 || cl.Net.Inline {
|
||||||
s.publishToSubscribers(pk)
|
s.publishToSubscribers(pk)
|
||||||
s.hooks.OnPublished(cl, pk)
|
s.hooks.OnPublished(cl, pk)
|
||||||
return nil
|
return nil
|
||||||
@@ -809,6 +874,10 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
|||||||
subscribers.MergeSharedSelected()
|
subscribers.MergeSharedSelected()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, inlineSubscription := range subscribers.InlineSubscriptions {
|
||||||
|
inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk)
|
||||||
|
}
|
||||||
|
|
||||||
for id, subs := range subscribers.Subscriptions {
|
for id, subs := range subscribers.Subscriptions {
|
||||||
if cl, ok := s.Clients.Get(id); ok {
|
if cl, ok := s.Clients.Get(id); ok {
|
||||||
_, err := s.publishToClient(cl, subs, pk)
|
_, err := s.publishToClient(cl, subs, pk)
|
||||||
|
349
server_test.go
349
server_test.go
@@ -117,6 +117,8 @@ func TestNew(t *testing.T) {
|
|||||||
require.NotNil(t, s.hooks)
|
require.NotNil(t, s.hooks)
|
||||||
require.NotNil(t, s.hooks.Log)
|
require.NotNil(t, s.hooks.Log)
|
||||||
require.NotNil(t, s.done)
|
require.NotNil(t, s.done)
|
||||||
|
require.NotNil(t, s.inlineClient)
|
||||||
|
require.Equal(t, 1, s.Clients.Len())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewNilOpts(t *testing.T) {
|
func TestNewNilOpts(t *testing.T) {
|
||||||
@@ -348,9 +350,13 @@ func TestEstablishConnection(t *testing.T) {
|
|||||||
|
|
||||||
err := <-o
|
err := <-o
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for _, v := range s.Clients.GetAll() {
|
|
||||||
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
|
// Todo:
|
||||||
}
|
// s.Clients is already empty here. Is it necessary to check v.StopCause()?
|
||||||
|
|
||||||
|
// for _, v := range s.Clients.GetAll() {
|
||||||
|
// require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
|
||||||
|
// }
|
||||||
|
|
||||||
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
|
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
|
||||||
|
|
||||||
@@ -409,9 +415,11 @@ func TestEstablishConnectionReadError(t *testing.T) {
|
|||||||
|
|
||||||
err := <-o
|
err := <-o
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
for _, v := range s.Clients.GetAll() {
|
|
||||||
require.ErrorIs(t, v.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect
|
// Retrieve the client corresponding to the Client Identifier.
|
||||||
}
|
retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect
|
||||||
|
|
||||||
ret := <-recv
|
ret := <-recv
|
||||||
require.Equal(t, append(
|
require.Equal(t, append(
|
||||||
@@ -467,9 +475,11 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
|
|||||||
|
|
||||||
err := <-o
|
err := <-o
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for _, v := range s.Clients.GetAll() {
|
|
||||||
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
|
// Retrieve the client corresponding to the Client Identifier.
|
||||||
}
|
retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
|
||||||
|
|
||||||
connackPlusPacket := append(
|
connackPlusPacket := append(
|
||||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes,
|
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes,
|
||||||
@@ -657,9 +667,11 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
|
|||||||
|
|
||||||
err := <-o
|
err := <-o
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for _, v := range s.Clients.GetAll() {
|
|
||||||
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
|
// Retrieve the client corresponding to the Client Identifier.
|
||||||
}
|
retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
|
||||||
|
|
||||||
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
|
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
|
||||||
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover)
|
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover)
|
||||||
@@ -2973,9 +2985,11 @@ func TestServerLoadClients(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := newServer()
|
s := newServer()
|
||||||
require.Equal(t, 0, s.Clients.Len())
|
// Here, server.Clients also includes a server.inlineClient,
|
||||||
|
// so the total number of clients here should be increased by 1
|
||||||
|
require.Equal(t, 1, s.Clients.Len())
|
||||||
s.loadClients(v)
|
s.loadClients(v)
|
||||||
require.Equal(t, 3, s.Clients.Len())
|
require.Equal(t, 4, s.Clients.Len())
|
||||||
cl, ok := s.Clients.Get("mochi")
|
cl, ok := s.Clients.Get("mochi")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, "mochi", cl.ID)
|
require.Equal(t, "mochi", cl.ID)
|
||||||
@@ -3003,7 +3017,10 @@ func TestServerLoadInflightMessages(t *testing.T) {
|
|||||||
{ID: "zen"},
|
{ID: "zen"},
|
||||||
{ID: "mochi-co"},
|
{ID: "mochi-co"},
|
||||||
})
|
})
|
||||||
require.Equal(t, 3, s.Clients.Len())
|
|
||||||
|
// server.Clients also includes a server.inlineClient,
|
||||||
|
// so the total number of clients here should be increased by 1
|
||||||
|
require.Equal(t, 4, s.Clients.Len())
|
||||||
|
|
||||||
v := []storage.Message{
|
v := []storage.Message{
|
||||||
{Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"},
|
{Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"},
|
||||||
@@ -3159,11 +3176,15 @@ func TestServerClearExpiredClients(t *testing.T) {
|
|||||||
cl2.Properties.Props.SessionExpiryIntervalFlag = true
|
cl2.Properties.Props.SessionExpiryIntervalFlag = true
|
||||||
s.Clients.Add(cl2)
|
s.Clients.Add(cl2)
|
||||||
|
|
||||||
require.Equal(t, 4, s.Clients.Len())
|
// server.Clients also includes a server.inlineClient,
|
||||||
|
// so the total number of clients here should be increased by 1
|
||||||
|
require.Equal(t, 5, s.Clients.Len())
|
||||||
|
|
||||||
s.clearExpiredClients(n)
|
s.clearExpiredClients(n)
|
||||||
|
|
||||||
require.Equal(t, 2, s.Clients.Len())
|
// server.Clients also includes a server.inlineClient,
|
||||||
|
// so the total number of clients here should be increased by 1
|
||||||
|
require.Equal(t, 3, s.Clients.Len())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadServerInfoRestoreOnRestart(t *testing.T) {
|
func TestLoadServerInfoRestoreOnRestart(t *testing.T) {
|
||||||
@@ -3182,3 +3203,297 @@ func TestAtomicItoa(t *testing.T) {
|
|||||||
ip := &i
|
ip := &i
|
||||||
require.Equal(t, "22", AtomicItoa(ip))
|
require.Equal(t, "22", AtomicItoa(ip))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerSubscribe(t *testing.T) {
|
||||||
|
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(nil)
|
||||||
|
require.NotNil(t, s)
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
desc string
|
||||||
|
filter string
|
||||||
|
identifier int
|
||||||
|
handler InlineSubFn
|
||||||
|
expect error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "subscribe",
|
||||||
|
filter: "a/b/c",
|
||||||
|
identifier: 1,
|
||||||
|
handler: handler,
|
||||||
|
expect: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "re-subscribe",
|
||||||
|
filter: "a/b/c",
|
||||||
|
identifier: 1,
|
||||||
|
handler: handler,
|
||||||
|
expect: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe d/e/f",
|
||||||
|
filter: "d/e/f",
|
||||||
|
identifier: 1,
|
||||||
|
handler: handler,
|
||||||
|
expect: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "re-subscribe d/e/f by different identifier",
|
||||||
|
filter: "d/e/f",
|
||||||
|
identifier: 2,
|
||||||
|
handler: handler,
|
||||||
|
expect: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe different handler",
|
||||||
|
filter: "a/b/c",
|
||||||
|
identifier: 1,
|
||||||
|
handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {},
|
||||||
|
expect: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe $SYS/info",
|
||||||
|
filter: "$SYS/info",
|
||||||
|
identifier: 1,
|
||||||
|
handler: handler,
|
||||||
|
expect: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe invalied ###",
|
||||||
|
filter: "###",
|
||||||
|
identifier: 1,
|
||||||
|
handler: handler,
|
||||||
|
expect: packets.ErrTopicFilterInvalid,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe invalid handler",
|
||||||
|
filter: "a/b/c",
|
||||||
|
identifier: 1,
|
||||||
|
handler: nil,
|
||||||
|
expect: packets.ErrInlineSubscriptionHandlerInvalid,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tx := range tt {
|
||||||
|
t.Run(tx.desc, func(t *testing.T) {
|
||||||
|
require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerUnsubscribe(t *testing.T) {
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(nil)
|
||||||
|
err := s.Subscribe("a/b/c", 1, handler)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Subscribe("d/e/f", 1, handler)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Subscribe("d/e/f", 2, handler)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Unsubscribe("a/b/c", 1)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Unsubscribe("d/e/f", 1)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Unsubscribe("d/e/f", 2)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Unsubscribe("not/exist", 1)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Unsubscribe("#/#/invalid", 1)
|
||||||
|
require.Equal(t, packets.ErrTopicFilterInvalid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublishToInlineSubscriber(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
finishCh := make(chan bool)
|
||||||
|
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||||
|
s.publishToSubscribers(pkx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Equal(t, true, <-finishCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublishToInlineSubscribersDiffrentFilter(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
subNumber := 2
|
||||||
|
finishCh := make(chan bool, subNumber)
|
||||||
|
|
||||||
|
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("mochi mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "z/e/n", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||||
|
s.publishToSubscribers(pkx)
|
||||||
|
|
||||||
|
pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet
|
||||||
|
s.publishToSubscribers(pkx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < subNumber; i++ {
|
||||||
|
require.Equal(t, true, <-finishCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublishToInlineSubscribersDiffrentIdentifier(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
subNumber := 2
|
||||||
|
finishCh := make(chan bool, subNumber)
|
||||||
|
|
||||||
|
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 2, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||||
|
s.publishToSubscribers(pkx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < subNumber; i++ {
|
||||||
|
require.Equal(t, true, <-finishCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerSubscribeWithRetain(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
subNumber := 1
|
||||||
|
finishCh := make(chan bool, subNumber)
|
||||||
|
|
||||||
|
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||||
|
require.Equal(t, int64(1), retained)
|
||||||
|
|
||||||
|
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, true, <-finishCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerSubscribeWithRetainDiffrentFilter(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
subNumber := 2
|
||||||
|
finishCh := make(chan bool, subNumber)
|
||||||
|
|
||||||
|
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||||
|
require.Equal(t, int64(1), retained)
|
||||||
|
retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet)
|
||||||
|
require.Equal(t, int64(1), retained)
|
||||||
|
|
||||||
|
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("mochi mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "z/e/n", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < subNumber; i++ {
|
||||||
|
require.Equal(t, true, <-finishCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerSubscribeWithRetainDiffrentIdentifier(t *testing.T) {
|
||||||
|
s := newServer()
|
||||||
|
subNumber := 2
|
||||||
|
finishCh := make(chan bool, subNumber)
|
||||||
|
|
||||||
|
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||||
|
require.Equal(t, int64(1), retained)
|
||||||
|
|
||||||
|
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 1, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
require.Equal(t, []byte("hello mochi"), pk.Payload)
|
||||||
|
require.Equal(t, InlineClientId, cl.ID)
|
||||||
|
require.Equal(t, LocalListener, cl.Net.Listener)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, 2, sub.Identifier)
|
||||||
|
finishCh <- true
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
for i := 0; i < subNumber; i++ {
|
||||||
|
require.Equal(t, true, <-finishCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
151
topics.go
151
topics.go
@@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InlineSubFn is the signature for a callback function which will be called
|
||||||
|
// when an inline client receives a message on a topic it is subscribed to.
|
||||||
|
// The sub argument contains information about the subscription that was matched for any filters.
|
||||||
|
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
|
||||||
|
|
||||||
|
// InlineSubscriptions represents a map of internal subscriptions keyed on client.
|
||||||
|
type InlineSubscriptions struct {
|
||||||
|
internal map[int]InlineSubscription
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
|
||||||
|
func NewInlineSubscriptions() *InlineSubscriptions {
|
||||||
|
return &InlineSubscriptions{
|
||||||
|
internal: map[int]InlineSubscription{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a new internal subscription for a client id.
|
||||||
|
func (s *InlineSubscriptions) Add(val InlineSubscription) {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
s.internal[val.Identifier] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all internal subscriptions.
|
||||||
|
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
m := map[int]InlineSubscription{}
|
||||||
|
for k, v := range s.internal {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns an internal subscription for a client id.
|
||||||
|
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
val, ok = s.internal[id]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of internal subscriptions.
|
||||||
|
func (s *InlineSubscriptions) Len() int {
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
val := len(s.internal)
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes an internal subscription by the client id.
|
||||||
|
func (s *InlineSubscriptions) Delete(id int) {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
delete(s.internal, id)
|
||||||
|
}
|
||||||
|
|
||||||
// Subscriptions is a map of subscriptions keyed on client.
|
// Subscriptions is a map of subscriptions keyed on client.
|
||||||
type Subscriptions struct {
|
type Subscriptions struct {
|
||||||
internal map[string]packets.Subscription
|
internal map[string]packets.Subscription
|
||||||
@@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) {
|
|||||||
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
||||||
type ClientSubscriptions map[string]packets.Subscription
|
type ClientSubscriptions map[string]packets.Subscription
|
||||||
|
|
||||||
|
type InlineSubscription struct {
|
||||||
|
packets.Subscription
|
||||||
|
Handler InlineSubFn
|
||||||
|
}
|
||||||
|
|
||||||
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
||||||
type Subscribers struct {
|
type Subscribers struct {
|
||||||
Shared map[string]map[string]packets.Subscription
|
Shared map[string]map[string]packets.Subscription
|
||||||
SharedSelected map[string]packets.Subscription
|
SharedSelected map[string]packets.Subscription
|
||||||
Subscriptions map[string]packets.Subscription
|
Subscriptions map[string]packets.Subscription
|
||||||
|
InlineSubscriptions map[int]InlineSubscription
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectShared returns one subscriber for each shared subscription group.
|
// SelectShared returns one subscriber for each shared subscription group.
|
||||||
@@ -298,6 +363,39 @@ func NewTopicsIndex() *TopicsIndex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InlineSubscribe adds a new internal subscription for a topic filter, returning
|
||||||
|
// true if the subscription was new.
|
||||||
|
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
|
||||||
|
x.root.Lock()
|
||||||
|
defer x.root.Unlock()
|
||||||
|
|
||||||
|
var existed bool
|
||||||
|
n := x.set(subscription.Filter, 0)
|
||||||
|
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
|
||||||
|
n.inlineSubscriptions.Add(subscription)
|
||||||
|
|
||||||
|
return !existed
|
||||||
|
}
|
||||||
|
|
||||||
|
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
|
||||||
|
// returning true if the subscription existed.
|
||||||
|
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
|
||||||
|
x.root.Lock()
|
||||||
|
defer x.root.Unlock()
|
||||||
|
|
||||||
|
particle := x.seek(filter, 0)
|
||||||
|
if particle == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
particle.inlineSubscriptions.Delete(id)
|
||||||
|
|
||||||
|
if particle.inlineSubscriptions.Len() == 0 {
|
||||||
|
x.trim(particle)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||||
// true if the subscription was new.
|
// true if the subscription was new.
|
||||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||||
@@ -484,9 +582,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack
|
|||||||
// their subscription ids and highest qos.
|
// their subscription ids and highest qos.
|
||||||
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
||||||
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
||||||
Shared: map[string]map[string]packets.Subscription{},
|
Shared: map[string]map[string]packets.Subscription{},
|
||||||
SharedSelected: map[string]packets.Subscription{},
|
SharedSelected: map[string]packets.Subscription{},
|
||||||
Subscriptions: map[string]packets.Subscription{},
|
Subscriptions: map[string]packets.Subscription{},
|
||||||
|
InlineSubscriptions: map[int]InlineSubscription{},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -508,10 +607,12 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
|
|||||||
} else {
|
} else {
|
||||||
x.gatherSubscriptions(topic, particle, subs)
|
x.gatherSubscriptions(topic, particle, subs)
|
||||||
x.gatherSharedSubscriptions(particle, subs)
|
x.gatherSharedSubscriptions(particle, subs)
|
||||||
|
x.gatherInlineSubscriptions(particle, subs)
|
||||||
|
|
||||||
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
|
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
|
||||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||||
x.gatherSharedSubscriptions(wild, subs)
|
x.gatherSharedSubscriptions(wild, subs)
|
||||||
|
x.gatherInlineSubscriptions(particle, subs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -520,6 +621,7 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
|
|||||||
if particle := n.particles.get("#"); particle != nil {
|
if particle := n.particles.get("#"); particle != nil {
|
||||||
x.gatherSubscriptions(topic, particle, subs)
|
x.gatherSubscriptions(topic, particle, subs)
|
||||||
x.gatherSharedSubscriptions(particle, subs)
|
x.gatherSharedSubscriptions(particle, subs)
|
||||||
|
x.gatherInlineSubscriptions(particle, subs)
|
||||||
}
|
}
|
||||||
|
|
||||||
return subs
|
return subs
|
||||||
@@ -562,6 +664,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
|
||||||
|
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
|
||||||
|
if subs.InlineSubscriptions == nil {
|
||||||
|
subs.InlineSubscriptions = map[int]InlineSubscription{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, inline := range particle.inlineSubscriptions.GetAll() {
|
||||||
|
subs.InlineSubscriptions[id] = inline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||||
var next, end int
|
var next, end int
|
||||||
@@ -633,23 +746,25 @@ func IsValidFilter(filter string, forPublish bool) bool {
|
|||||||
|
|
||||||
// particle is a child node on the tree.
|
// particle is a child node on the tree.
|
||||||
type particle struct {
|
type particle struct {
|
||||||
key string // the key of the particle
|
key string // the key of the particle
|
||||||
parent *particle // a pointer to the parent of the particle
|
parent *particle // a pointer to the parent of the particle
|
||||||
particles particles // a map of child particles
|
particles particles // a map of child particles
|
||||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||||
retainPath string // path of a retained message
|
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
|
||||||
sync.Mutex // mutex for when making changes to the particle
|
retainPath string // path of a retained message
|
||||||
|
sync.Mutex // mutex for when making changes to the particle
|
||||||
}
|
}
|
||||||
|
|
||||||
// newParticle returns a pointer to a new instance of particle.
|
// newParticle returns a pointer to a new instance of particle.
|
||||||
func newParticle(key string, parent *particle) *particle {
|
func newParticle(key string, parent *particle) *particle {
|
||||||
return &particle{
|
return &particle{
|
||||||
key: key,
|
key: key,
|
||||||
parent: parent,
|
parent: parent,
|
||||||
particles: newParticles(),
|
particles: newParticles(),
|
||||||
subscriptions: NewSubscriptions(),
|
subscriptions: NewSubscriptions(),
|
||||||
shared: NewSharedSubscriptions(),
|
shared: NewSharedSubscriptions(),
|
||||||
|
inlineSubscriptions: NewInlineSubscriptions(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
213
topics_test.go
213
topics_test.go
@@ -5,6 +5,7 @@
|
|||||||
package mqtt
|
package mqtt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mochi-mqtt/server/v2/packets"
|
"github.com/mochi-mqtt/server/v2/packets"
|
||||||
@@ -853,3 +854,215 @@ func TestNewTopicAliases(t *testing.T) {
|
|||||||
require.NotNil(t, a.Outbound)
|
require.NotNil(t, a.Outbound)
|
||||||
require.Equal(t, uint16(5), a.Outbound.maximum)
|
require.Equal(t, uint16(5), a.Outbound.maximum)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewInlineSubscriptions(t *testing.T) {
|
||||||
|
subscriptions := NewInlineSubscriptions()
|
||||||
|
require.NotNil(t, subscriptions)
|
||||||
|
require.NotNil(t, subscriptions.internal)
|
||||||
|
require.Equal(t, 0, subscriptions.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineSubscriptionAdd(t *testing.T) {
|
||||||
|
subscriptions := NewInlineSubscriptions()
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription := InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
subscriptions.Add(subscription)
|
||||||
|
|
||||||
|
sub, ok := subscriptions.Get(1)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineSubscriptionGet(t *testing.T) {
|
||||||
|
subscriptions := NewInlineSubscriptions()
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription := InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
subscriptions.Add(subscription)
|
||||||
|
|
||||||
|
sub, ok := subscriptions.Get(1)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "a/b/c", sub.Filter)
|
||||||
|
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
|
||||||
|
|
||||||
|
_, ok = subscriptions.Get(999)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineSubscriptionsGetAll(t *testing.T) {
|
||||||
|
subscriptions := NewInlineSubscriptions()
|
||||||
|
|
||||||
|
subscriptions.Add(InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||||
|
})
|
||||||
|
subscriptions.Add(InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||||
|
})
|
||||||
|
subscriptions.Add(InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2},
|
||||||
|
})
|
||||||
|
subscriptions.Add(InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3},
|
||||||
|
})
|
||||||
|
|
||||||
|
allSubs := subscriptions.GetAll()
|
||||||
|
require.Len(t, allSubs, 3)
|
||||||
|
require.Contains(t, allSubs, 1)
|
||||||
|
require.Contains(t, allSubs, 2)
|
||||||
|
require.Contains(t, allSubs, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineSubscriptionDelete(t *testing.T) {
|
||||||
|
subscriptions := NewInlineSubscriptions()
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription := InlineSubscription{
|
||||||
|
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
subscriptions.Add(subscription)
|
||||||
|
|
||||||
|
subscriptions.Delete(1)
|
||||||
|
_, ok := subscriptions.Get(1)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Empty(t, subscriptions.GetAll())
|
||||||
|
require.Zero(t, subscriptions.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineSubscribe(t *testing.T) {
|
||||||
|
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
desc string
|
||||||
|
filter string
|
||||||
|
subscription InlineSubscription
|
||||||
|
wasNew bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "subscribe",
|
||||||
|
filter: "a/b/c",
|
||||||
|
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
|
||||||
|
wasNew: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe existed",
|
||||||
|
filter: "a/b/c",
|
||||||
|
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
|
||||||
|
wasNew: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe different identifier",
|
||||||
|
filter: "a/b/c",
|
||||||
|
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}},
|
||||||
|
wasNew: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "subscribe case sensitive didnt exist",
|
||||||
|
filter: "A/B/c",
|
||||||
|
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}},
|
||||||
|
wasNew: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "wildcard+ sub",
|
||||||
|
filter: "d/+",
|
||||||
|
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}},
|
||||||
|
wasNew: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "wildcard# sub",
|
||||||
|
filter: "d/e/#",
|
||||||
|
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}},
|
||||||
|
wasNew: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
index := NewTopicsIndex()
|
||||||
|
for _, tx := range tt {
|
||||||
|
t.Run(tx.desc, func(t *testing.T) {
|
||||||
|
require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||||
|
require.NotNil(t, final)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineUnsubscribe(t *testing.T) {
|
||||||
|
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
|
||||||
|
// handler logic
|
||||||
|
}
|
||||||
|
|
||||||
|
index := NewTopicsIndex()
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
|
||||||
|
sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index = NewTopicsIndex()
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
|
||||||
|
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||||
|
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}})
|
||||||
|
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}})
|
||||||
|
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||||
|
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
|
||||||
|
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}})
|
||||||
|
sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
ok := index.InlineUnsubscribe(1, "a/b/c/d")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||||
|
|
||||||
|
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
|
||||||
|
require.NotNil(t, sub)
|
||||||
|
require.True(t, exists)
|
||||||
|
|
||||||
|
ok = index.InlineUnsubscribe(1, "d/e/f")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f"))
|
||||||
|
|
||||||
|
ok = index.InlineUnsubscribe(1, "not/exist")
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user