Another code implementation for Inline Client Subscriptions. (#284)

* Another code implementation for Inline Client Subscriptions.

* Added a few test cases.

* Changed the return value types of Server.Unsubscribe() and Subscribe() to boolean.

* Implementing the delivery of retained messages and supporting multiple callbacks per topic using different inline client IDs.

* Added validation checks for the legality of the inline client id during Subscribe and Unsubscribe.

* Added validation checks for the legality of the client during Subscribe and Unsubscribe.

* Fixed the TestServerSubscribe/invalid_client_id test case failure.

* Add Server.inlineClient and Temporarily removing test cases for better code review readability.

* Using server.inlineClient in server.InjectPacket().

* After unsubscribing, if there are other subscriptions in particle.inlineSubscriptions, particle cannot be deleted.

* Add comments to particle.inlineSubscriptions and modify to return ErrTopicFilterInvalid when the topic is invalid during subscription.

* Fixed some test case failures caused by adding inlineClient to the server.

* More test cases have been added.

* Optimization of test case code.

* Modify server.go: When used as a publisher, treat the qos of inline client-published messages as 0.

* Resolve conflict.
This commit is contained in:
werbenhu
2023-09-09 03:45:08 +08:00
committed by GitHub
parent 44bac0adc5
commit 1574443981
6 changed files with 855 additions and 47 deletions

95
examples/direct/main.go Normal file
View File

@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-mqtt/server/v2/hooks/auth"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
// Start the server
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Demonstration of using an inline client to directly subscribe to a topic and receive a message when
// that subscription is activated. The inline subscription method uses the same internal subscription logic
// as used for external (normal) clients.
go func() {
// Inline subscriptions can also receive retained messages on subscription.
_ = server.Publish("direct/retained", []byte("retained message"), true, 0)
_ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0)
// Subscribe to a filter and handle any received messages via a callback function.
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Log.Info("inline client subscribing")
server.Subscribe("direct/#", 1, callbackFn)
server.Subscribe("direct/#", 2, callbackFn)
}()
// There is a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 3) {
err := server.Publish("direct/publish", []byte("scheduled message"), false, 0)
if err != nil {
server.Log.Error("server.Publish", "error", err)
}
server.Log.Info("main.go issued direct message to direct/publish")
}
}()
go func() {
time.Sleep(time.Second * 10)
// Unsubscribe from the same filter to stop receiving messages.
server.Log.Info("inline client unsubscribing")
server.Unsubscribe("direct/#", 1)
}()
// If you want to have more control over your packets, you can directly inject a packet of any kind into the broker.
//go func() {
// for range time.Tick(time.Second * 5) {
// err := server.InjectPacket(cl, packets.Packet{
// FixedHeader: packets.FixedHeader{
// Type: packets.Publish,
// },
// TopicName: "direct/publish",
// Payload: []byte("injected scheduled message"),
// })
// if err != nil {
// log.Fatal(err)
// }
// }
//}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -126,6 +126,7 @@ var (
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"} ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
// MQTTv3 specific bytes. // MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01} Err3UnsupportedProtocolVersion = Code{Code: 0x01}

View File

@@ -28,6 +28,8 @@ import (
const ( const (
Version = "2.3.0" // the current server version. Version = "2.3.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
LocalListener = "local"
InlineClientId = "inline"
) )
var ( var (
@@ -118,6 +120,7 @@ type Server struct {
done chan bool // indicate that the server is ending done chan bool // indicate that the server is ending
Log *slog.Logger // minimal no-alloc logger Log *slog.Logger // minimal no-alloc logger
hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage. hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage.
inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish.
} }
// loop contains interval tickers for the system events loop. // loop contains interval tickers for the system events loop.
@@ -170,6 +173,8 @@ func New(opts *Options) *Server {
Log: opts.Logger, Log: opts.Logger,
}, },
} }
s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true)
s.Clients.Add(s.inlineClient)
return s return s
} }
@@ -649,8 +654,7 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the // to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). // outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
cl := s.NewClient(nil, "local", "inline", true) return s.InjectPacket(s.inlineClient, packets.Packet{
return s.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Publish, Type: packets.Publish,
Qos: qos, Qos: qos,
@@ -662,6 +666,64 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er
}) })
} }
// Subscribe adds an inline subscription for the specified topic filter and subscription identifier
// with the provided handler function.
func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error {
if handler == nil {
return packets.ErrInlineSubscriptionHandlerInvalid
} else if !IsValidFilter(filter, false) {
return packets.ErrTopicFilterInvalid
}
subscription := packets.Subscription{
Identifier: subscriptionId,
Filter: filter,
}
pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client.
Origin: s.inlineClient.ID,
FixedHeader: packets.FixedHeader{Type: packets.Subscribe},
Filters: packets.Subscriptions{subscription},
})
inlineSubscription := InlineSubscription{
Subscription: subscription,
Handler: handler,
}
s.Topics.InlineSubscribe(inlineSubscription)
s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code})
// Handling retained messages.
for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4]
handler(s.inlineClient, inlineSubscription.Subscription, pkv)
}
return nil
}
// Unsubscribe removes an inline subscription for the specified subscription and topic filter.
// It allows you to unsubscribe a specific subscription from the internal subscription
// associated with the given topic filter.
func (s *Server) Unsubscribe(filter string, subscriptionId int) error {
if !IsValidFilter(filter, false) {
return packets.ErrTopicFilterInvalid
}
pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{
Origin: s.inlineClient.ID,
FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe},
Filters: packets.Subscriptions{
{
Identifier: subscriptionId,
Filter: filter,
},
},
})
s.Topics.InlineUnsubscribe(subscriptionId, filter)
s.hooks.OnUnsubscribed(s.inlineClient, pk)
return nil
}
// InjectPacket injects a packet into the broker as if it were sent from the specified client. // InjectPacket injects a packet into the broker as if it were sent from the specified client.
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
@@ -736,7 +798,10 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.retainMessage(cl, pk) s.retainMessage(cl, pk)
} }
if pk.FixedHeader.Qos == 0 { // If it's inlineClient, it can't handle PUBREC and PUBREL.
// When it publishes a package with a qos > 0, the server treats
// the package as qos=0, and the client receives it as qos=1 or 2.
if pk.FixedHeader.Qos == 0 || cl.Net.Inline {
s.publishToSubscribers(pk) s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk) s.hooks.OnPublished(cl, pk)
return nil return nil
@@ -809,6 +874,10 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
subscribers.MergeSharedSelected() subscribers.MergeSharedSelected()
} }
for _, inlineSubscription := range subscribers.InlineSubscriptions {
inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk)
}
for id, subs := range subscribers.Subscriptions { for id, subs := range subscribers.Subscriptions {
if cl, ok := s.Clients.Get(id); ok { if cl, ok := s.Clients.Get(id); ok {
_, err := s.publishToClient(cl, subs, pk) _, err := s.publishToClient(cl, subs, pk)

View File

@@ -117,6 +117,8 @@ func TestNew(t *testing.T) {
require.NotNil(t, s.hooks) require.NotNil(t, s.hooks)
require.NotNil(t, s.hooks.Log) require.NotNil(t, s.hooks.Log)
require.NotNil(t, s.done) require.NotNil(t, s.done)
require.NotNil(t, s.inlineClient)
require.Equal(t, 1, s.Clients.Len())
} }
func TestNewNilOpts(t *testing.T) { func TestNewNilOpts(t *testing.T) {
@@ -348,9 +350,13 @@ func TestEstablishConnection(t *testing.T) {
err := <-o err := <-o
require.NoError(t, err) require.NoError(t, err)
for _, v := range s.Clients.GetAll() {
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect // Todo:
} // s.Clients is already empty here. Is it necessary to check v.StopCause()?
// for _, v := range s.Clients.GetAll() {
// require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
// }
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
@@ -409,9 +415,11 @@ func TestEstablishConnectionReadError(t *testing.T) {
err := <-o err := <-o
require.Error(t, err) require.Error(t, err)
for _, v := range s.Clients.GetAll() {
require.ErrorIs(t, v.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect // Retrieve the client corresponding to the Client Identifier.
} retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect
ret := <-recv ret := <-recv
require.Equal(t, append( require.Equal(t, append(
@@ -467,9 +475,11 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
err := <-o err := <-o
require.NoError(t, err) require.NoError(t, err)
for _, v := range s.Clients.GetAll() {
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect // Retrieve the client corresponding to the Client Identifier.
} retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
connackPlusPacket := append( connackPlusPacket := append(
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes,
@@ -657,9 +667,11 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
err := <-o err := <-o
require.NoError(t, err) require.NoError(t, err)
for _, v := range s.Clients.GetAll() {
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect // Retrieve the client corresponding to the Client Identifier.
} retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover)
@@ -2973,9 +2985,11 @@ func TestServerLoadClients(t *testing.T) {
} }
s := newServer() s := newServer()
require.Equal(t, 0, s.Clients.Len()) // Here, server.Clients also includes a server.inlineClient,
// so the total number of clients here should be increased by 1
require.Equal(t, 1, s.Clients.Len())
s.loadClients(v) s.loadClients(v)
require.Equal(t, 3, s.Clients.Len()) require.Equal(t, 4, s.Clients.Len())
cl, ok := s.Clients.Get("mochi") cl, ok := s.Clients.Get("mochi")
require.True(t, ok) require.True(t, ok)
require.Equal(t, "mochi", cl.ID) require.Equal(t, "mochi", cl.ID)
@@ -3003,7 +3017,10 @@ func TestServerLoadInflightMessages(t *testing.T) {
{ID: "zen"}, {ID: "zen"},
{ID: "mochi-co"}, {ID: "mochi-co"},
}) })
require.Equal(t, 3, s.Clients.Len())
// server.Clients also includes a server.inlineClient,
// so the total number of clients here should be increased by 1
require.Equal(t, 4, s.Clients.Len())
v := []storage.Message{ v := []storage.Message{
{Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"}, {Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"},
@@ -3159,11 +3176,15 @@ func TestServerClearExpiredClients(t *testing.T) {
cl2.Properties.Props.SessionExpiryIntervalFlag = true cl2.Properties.Props.SessionExpiryIntervalFlag = true
s.Clients.Add(cl2) s.Clients.Add(cl2)
require.Equal(t, 4, s.Clients.Len()) // server.Clients also includes a server.inlineClient,
// so the total number of clients here should be increased by 1
require.Equal(t, 5, s.Clients.Len())
s.clearExpiredClients(n) s.clearExpiredClients(n)
require.Equal(t, 2, s.Clients.Len()) // server.Clients also includes a server.inlineClient,
// so the total number of clients here should be increased by 1
require.Equal(t, 3, s.Clients.Len())
} }
func TestLoadServerInfoRestoreOnRestart(t *testing.T) { func TestLoadServerInfoRestoreOnRestart(t *testing.T) {
@@ -3182,3 +3203,297 @@ func TestAtomicItoa(t *testing.T) {
ip := &i ip := &i
require.Equal(t, "22", AtomicItoa(ip)) require.Equal(t, "22", AtomicItoa(ip))
} }
func TestServerSubscribe(t *testing.T) {
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
s := New(nil)
require.NotNil(t, s)
tt := []struct {
desc string
filter string
identifier int
handler InlineSubFn
expect error
}{
{
desc: "subscribe",
filter: "a/b/c",
identifier: 1,
handler: handler,
expect: nil,
},
{
desc: "re-subscribe",
filter: "a/b/c",
identifier: 1,
handler: handler,
expect: nil,
},
{
desc: "subscribe d/e/f",
filter: "d/e/f",
identifier: 1,
handler: handler,
expect: nil,
},
{
desc: "re-subscribe d/e/f by different identifier",
filter: "d/e/f",
identifier: 2,
handler: handler,
expect: nil,
},
{
desc: "subscribe different handler",
filter: "a/b/c",
identifier: 1,
handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {},
expect: nil,
},
{
desc: "subscribe $SYS/info",
filter: "$SYS/info",
identifier: 1,
handler: handler,
expect: nil,
},
{
desc: "subscribe invalied ###",
filter: "###",
identifier: 1,
handler: handler,
expect: packets.ErrTopicFilterInvalid,
},
{
desc: "subscribe invalid handler",
filter: "a/b/c",
identifier: 1,
handler: nil,
expect: packets.ErrInlineSubscriptionHandlerInvalid,
},
}
for _, tx := range tt {
t.Run(tx.desc, func(t *testing.T) {
require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler))
})
}
}
func TestServerUnsubscribe(t *testing.T) {
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
s := New(nil)
err := s.Subscribe("a/b/c", 1, handler)
require.Nil(t, err)
err = s.Subscribe("d/e/f", 1, handler)
require.Nil(t, err)
err = s.Subscribe("d/e/f", 2, handler)
require.Nil(t, err)
err = s.Unsubscribe("a/b/c", 1)
require.Nil(t, err)
err = s.Unsubscribe("d/e/f", 1)
require.Nil(t, err)
err = s.Unsubscribe("d/e/f", 2)
require.Nil(t, err)
err = s.Unsubscribe("not/exist", 1)
require.Nil(t, err)
err = s.Unsubscribe("#/#/invalid", 1)
require.Equal(t, packets.ErrTopicFilterInvalid, err)
}
func TestPublishToInlineSubscriber(t *testing.T) {
s := newServer()
finishCh := make(chan bool)
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
go func() {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
s.publishToSubscribers(pkx)
}()
require.Equal(t, true, <-finishCh)
}
func TestPublishToInlineSubscribersDiffrentFilter(t *testing.T) {
s := newServer()
subNumber := 2
finishCh := make(chan bool, subNumber)
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("mochi mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "z/e/n", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
go func() {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
s.publishToSubscribers(pkx)
pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet
s.publishToSubscribers(pkx)
}()
for i := 0; i < subNumber; i++ {
require.Equal(t, true, <-finishCh)
}
}
func TestPublishToInlineSubscribersDiffrentIdentifier(t *testing.T) {
s := newServer()
subNumber := 2
finishCh := make(chan bool, subNumber)
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 2, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
go func() {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
s.publishToSubscribers(pkx)
}()
for i := 0; i < subNumber; i++ {
require.Equal(t, true, <-finishCh)
}
}
func TestServerSubscribeWithRetain(t *testing.T) {
s := newServer()
subNumber := 1
finishCh := make(chan bool, subNumber)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
require.Equal(t, int64(1), retained)
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
require.Equal(t, true, <-finishCh)
}
func TestServerSubscribeWithRetainDiffrentFilter(t *testing.T) {
s := newServer()
subNumber := 2
finishCh := make(chan bool, subNumber)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
require.Equal(t, int64(1), retained)
retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet)
require.Equal(t, int64(1), retained)
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("mochi mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "z/e/n", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
for i := 0; i < subNumber; i++ {
require.Equal(t, true, <-finishCh)
}
}
func TestServerSubscribeWithRetainDiffrentIdentifier(t *testing.T) {
s := newServer()
subNumber := 2
finishCh := make(chan bool, subNumber)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
require.Equal(t, int64(1), retained)
err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 1, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) {
require.Equal(t, []byte("hello mochi"), pk.Payload)
require.Equal(t, InlineClientId, cl.ID)
require.Equal(t, LocalListener, cl.Net.Listener)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, 2, sub.Identifier)
finishCh <- true
})
require.Nil(t, err)
for i := 0; i < subNumber; i++ {
require.Equal(t, true, <-finishCh)
}
}

115
topics.go
View File

@@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio
return m return m
} }
// InlineSubFn is the signature for a callback function which will be called
// when an inline client receives a message on a topic it is subscribed to.
// The sub argument contains information about the subscription that was matched for any filters.
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
// InlineSubscriptions represents a map of internal subscriptions keyed on client.
type InlineSubscriptions struct {
internal map[int]InlineSubscription
sync.RWMutex
}
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
func NewInlineSubscriptions() *InlineSubscriptions {
return &InlineSubscriptions{
internal: map[int]InlineSubscription{},
}
}
// Add adds a new internal subscription for a client id.
func (s *InlineSubscriptions) Add(val InlineSubscription) {
s.Lock()
defer s.Unlock()
s.internal[val.Identifier] = val
}
// GetAll returns all internal subscriptions.
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
s.RLock()
defer s.RUnlock()
m := map[int]InlineSubscription{}
for k, v := range s.internal {
m[k] = v
}
return m
}
// Get returns an internal subscription for a client id.
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
s.RLock()
defer s.RUnlock()
val, ok = s.internal[id]
return val, ok
}
// Len returns the number of internal subscriptions.
func (s *InlineSubscriptions) Len() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Delete removes an internal subscription by the client id.
func (s *InlineSubscriptions) Delete(id int) {
s.Lock()
defer s.Unlock()
delete(s.internal, id)
}
// Subscriptions is a map of subscriptions keyed on client. // Subscriptions is a map of subscriptions keyed on client.
type Subscriptions struct { type Subscriptions struct {
internal map[string]packets.Subscription internal map[string]packets.Subscription
@@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) {
// ClientSubscriptions is a map of aggregated subscriptions for a client. // ClientSubscriptions is a map of aggregated subscriptions for a client.
type ClientSubscriptions map[string]packets.Subscription type ClientSubscriptions map[string]packets.Subscription
type InlineSubscription struct {
packets.Subscription
Handler InlineSubFn
}
// Subscribers contains the shared and non-shared subscribers matching a topic. // Subscribers contains the shared and non-shared subscribers matching a topic.
type Subscribers struct { type Subscribers struct {
Shared map[string]map[string]packets.Subscription Shared map[string]map[string]packets.Subscription
SharedSelected map[string]packets.Subscription SharedSelected map[string]packets.Subscription
Subscriptions map[string]packets.Subscription Subscriptions map[string]packets.Subscription
InlineSubscriptions map[int]InlineSubscription
} }
// SelectShared returns one subscriber for each shared subscription group. // SelectShared returns one subscriber for each shared subscription group.
@@ -298,6 +363,39 @@ func NewTopicsIndex() *TopicsIndex {
} }
} }
// InlineSubscribe adds a new internal subscription for a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
n := x.set(subscription.Filter, 0)
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
n.inlineSubscriptions.Add(subscription)
return !existed
}
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
// returning true if the subscription existed.
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
x.root.Lock()
defer x.root.Unlock()
particle := x.seek(filter, 0)
if particle == nil {
return false
}
particle.inlineSubscriptions.Delete(id)
if particle.inlineSubscriptions.Len() == 0 {
x.trim(particle)
}
return true
}
// Subscribe adds a new subscription for a client to a topic filter, returning // Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new. // true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool { func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
@@ -487,6 +585,7 @@ func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
Shared: map[string]map[string]packets.Subscription{}, Shared: map[string]map[string]packets.Subscription{},
SharedSelected: map[string]packets.Subscription{}, SharedSelected: map[string]packets.Subscription{},
Subscriptions: map[string]packets.Subscription{}, Subscriptions: map[string]packets.Subscription{},
InlineSubscriptions: map[int]InlineSubscription{},
}) })
} }
@@ -508,10 +607,12 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
} else { } else {
x.gatherSubscriptions(topic, particle, subs) x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs) x.gatherSharedSubscriptions(particle, subs)
x.gatherInlineSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "+" { if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
x.gatherSharedSubscriptions(wild, subs) x.gatherSharedSubscriptions(wild, subs)
x.gatherInlineSubscriptions(particle, subs)
} }
} }
} }
@@ -520,6 +621,7 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
if particle := n.particles.get("#"); particle != nil { if particle := n.particles.get("#"); particle != nil {
x.gatherSubscriptions(topic, particle, subs) x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs) x.gatherSharedSubscriptions(particle, subs)
x.gatherInlineSubscriptions(particle, subs)
} }
return subs return subs
@@ -562,6 +664,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr
} }
} }
// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
if subs.InlineSubscriptions == nil {
subs.InlineSubscriptions = map[int]InlineSubscription{}
}
for id, inline := range particle.inlineSubscriptions.GetAll() {
subs.InlineSubscriptions[id] = inline
}
}
// isolateParticle extracts a particle between d / and d+1 / without allocations. // isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) { func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int var next, end int
@@ -638,6 +751,7 @@ type particle struct {
particles particles // a map of child particles particles particles // a map of child particles
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
retainPath string // path of a retained message retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle sync.Mutex // mutex for when making changes to the particle
} }
@@ -650,6 +764,7 @@ func newParticle(key string, parent *particle) *particle {
particles: newParticles(), particles: newParticles(),
subscriptions: NewSubscriptions(), subscriptions: NewSubscriptions(),
shared: NewSharedSubscriptions(), shared: NewSharedSubscriptions(),
inlineSubscriptions: NewInlineSubscriptions(),
} }
} }

View File

@@ -5,6 +5,7 @@
package mqtt package mqtt
import ( import (
"fmt"
"testing" "testing"
"github.com/mochi-mqtt/server/v2/packets" "github.com/mochi-mqtt/server/v2/packets"
@@ -853,3 +854,215 @@ func TestNewTopicAliases(t *testing.T) {
require.NotNil(t, a.Outbound) require.NotNil(t, a.Outbound)
require.Equal(t, uint16(5), a.Outbound.maximum) require.Equal(t, uint16(5), a.Outbound.maximum)
} }
func TestNewInlineSubscriptions(t *testing.T) {
subscriptions := NewInlineSubscriptions()
require.NotNil(t, subscriptions)
require.NotNil(t, subscriptions.internal)
require.Equal(t, 0, subscriptions.Len())
}
func TestInlineSubscriptionAdd(t *testing.T) {
subscriptions := NewInlineSubscriptions()
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
subscription := InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
Handler: handler,
}
subscriptions.Add(subscription)
sub, ok := subscriptions.Get(1)
require.True(t, ok)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
}
func TestInlineSubscriptionGet(t *testing.T) {
subscriptions := NewInlineSubscriptions()
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
subscription := InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
Handler: handler,
}
subscriptions.Add(subscription)
sub, ok := subscriptions.Get(1)
require.True(t, ok)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
_, ok = subscriptions.Get(999)
require.False(t, ok)
}
func TestInlineSubscriptionsGetAll(t *testing.T) {
subscriptions := NewInlineSubscriptions()
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
})
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
})
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2},
})
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3},
})
allSubs := subscriptions.GetAll()
require.Len(t, allSubs, 3)
require.Contains(t, allSubs, 1)
require.Contains(t, allSubs, 2)
require.Contains(t, allSubs, 3)
}
func TestInlineSubscriptionDelete(t *testing.T) {
subscriptions := NewInlineSubscriptions()
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
subscription := InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
Handler: handler,
}
subscriptions.Add(subscription)
subscriptions.Delete(1)
_, ok := subscriptions.Get(1)
require.False(t, ok)
require.Empty(t, subscriptions.GetAll())
require.Zero(t, subscriptions.Len())
}
func TestInlineSubscribe(t *testing.T) {
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
tt := []struct {
desc string
filter string
subscription InlineSubscription
wasNew bool
}{
{
desc: "subscribe",
filter: "a/b/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
wasNew: true,
},
{
desc: "subscribe existed",
filter: "a/b/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
wasNew: false,
},
{
desc: "subscribe different identifier",
filter: "a/b/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}},
wasNew: true,
},
{
desc: "subscribe case sensitive didnt exist",
filter: "A/B/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}},
wasNew: true,
},
{
desc: "wildcard+ sub",
filter: "d/+",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}},
wasNew: true,
},
{
desc: "wildcard# sub",
filter: "d/e/#",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}},
wasNew: true,
},
}
index := NewTopicsIndex()
for _, tx := range tt {
t.Run(tx.desc, func(t *testing.T) {
require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription))
})
}
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
}
func TestInlineUnsubscribe(t *testing.T) {
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
index := NewTopicsIndex()
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index = NewTopicsIndex()
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}})
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}})
sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
ok := index.InlineUnsubscribe(1, "a/b/c/d")
require.True(t, ok)
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
ok = index.InlineUnsubscribe(1, "d/e/f")
require.True(t, ok)
require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f"))
ok = index.InlineUnsubscribe(1, "not/exist")
require.False(t, ok)
}