diff --git a/examples/direct/main.go b/examples/direct/main.go new file mode 100644 index 0000000..72c7c99 --- /dev/null +++ b/examples/direct/main.go @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co +// SPDX-FileContributor: mochi-co + +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/mochi-mqtt/server/v2/hooks/auth" + + mqtt "github.com/mochi-mqtt/server/v2" + "github.com/mochi-mqtt/server/v2/packets" +) + +func main() { + sigs := make(chan os.Signal, 1) + done := make(chan bool, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigs + done <- true + }() + + server := mqtt.New(nil) + _ = server.AddHook(new(auth.AllowHook), nil) + + // Start the server + go func() { + err := server.Serve() + if err != nil { + log.Fatal(err) + } + }() + + // Demonstration of using an inline client to directly subscribe to a topic and receive a message when + // that subscription is activated. The inline subscription method uses the same internal subscription logic + // as used for external (normal) clients. + go func() { + // Inline subscriptions can also receive retained messages on subscription. + _ = server.Publish("direct/retained", []byte("retained message"), true, 0) + _ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0) + + // Subscribe to a filter and handle any received messages via a callback function. + callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) { + server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload)) + } + server.Log.Info("inline client subscribing") + server.Subscribe("direct/#", 1, callbackFn) + server.Subscribe("direct/#", 2, callbackFn) + }() + + // There is a shorthand convenience function, Publish, for easily sending + // publish packets if you are not concerned with creating your own packets. + go func() { + for range time.Tick(time.Second * 3) { + err := server.Publish("direct/publish", []byte("scheduled message"), false, 0) + if err != nil { + server.Log.Error("server.Publish", "error", err) + } + server.Log.Info("main.go issued direct message to direct/publish") + } + }() + + go func() { + time.Sleep(time.Second * 10) + // Unsubscribe from the same filter to stop receiving messages. + server.Log.Info("inline client unsubscribing") + server.Unsubscribe("direct/#", 1) + }() + // If you want to have more control over your packets, you can directly inject a packet of any kind into the broker. + //go func() { + // for range time.Tick(time.Second * 5) { + // err := server.InjectPacket(cl, packets.Packet{ + // FixedHeader: packets.FixedHeader{ + // Type: packets.Publish, + // }, + // TopicName: "direct/publish", + // Payload: []byte("injected scheduled message"), + // }) + // if err != nil { + // log.Fatal(err) + // } + // } + //}() + + <-done + server.Log.Warn("caught signal, stopping...") + _ = server.Close() + server.Log.Info("main.go finished") +} diff --git a/packets/codes.go b/packets/codes.go index 154d7ae..5af1b74 100644 --- a/packets/codes.go +++ b/packets/codes.go @@ -126,6 +126,7 @@ var ( ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"} ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} + ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."} // MQTTv3 specific bytes. Err3UnsupportedProtocolVersion = Code{Code: 0x01} diff --git a/server.go b/server.go index f0235b5..19e0829 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,8 @@ import ( const ( Version = "2.3.0" // the current server version. defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes + LocalListener = "local" + InlineClientId = "inline" ) var ( @@ -109,15 +111,16 @@ type Options struct { // Server is an MQTT broker server. It should be created with server.New() // in order to ensure all the internal fields are correctly populated. type Server struct { - Options *Options // configurable server options - Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections - Clients *Clients // clients known to the broker - Topics *TopicsIndex // an index of topic filter subscriptions and retained messages - Info *system.Info // values about the server commonly known as $SYS topics - loop *loop // loop contains tickers for the system event loop - done chan bool // indicate that the server is ending - Log *slog.Logger // minimal no-alloc logger - hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage. + Options *Options // configurable server options + Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections + Clients *Clients // clients known to the broker + Topics *TopicsIndex // an index of topic filter subscriptions and retained messages + Info *system.Info // values about the server commonly known as $SYS topics + loop *loop // loop contains tickers for the system event loop + done chan bool // indicate that the server is ending + Log *slog.Logger // minimal no-alloc logger + hooks *Hooks // hooks contains hooks for extra functionality such as auth and persistent storage. + inlineClient *Client // inlineClient is a special client used for inline subscriptions and inline Publish. } // loop contains interval tickers for the system events loop. @@ -170,6 +173,8 @@ func New(opts *Options) *Server { Log: opts.Logger, }, } + s.inlineClient = s.NewClient(nil, LocalListener, InlineClientId, true) + s.Clients.Add(s.inlineClient) return s } @@ -649,8 +654,7 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { // to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the // outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { - cl := s.NewClient(nil, "local", "inline", true) - return s.InjectPacket(cl, packets.Packet{ + return s.InjectPacket(s.inlineClient, packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, Qos: qos, @@ -662,6 +666,64 @@ func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) er }) } +// Subscribe adds an inline subscription for the specified topic filter and subscription identifier +// with the provided handler function. +func (s *Server) Subscribe(filter string, subscriptionId int, handler InlineSubFn) error { + if handler == nil { + return packets.ErrInlineSubscriptionHandlerInvalid + } else if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + subscription := packets.Subscription{ + Identifier: subscriptionId, + Filter: filter, + } + + pk := s.hooks.OnSubscribe(s.inlineClient, packets.Packet{ // subscribe like a normal client. + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Subscribe}, + Filters: packets.Subscriptions{subscription}, + }) + + inlineSubscription := InlineSubscription{ + Subscription: subscription, + Handler: handler, + } + + s.Topics.InlineSubscribe(inlineSubscription) + s.hooks.OnSubscribed(s.inlineClient, pk, []byte{packets.CodeSuccess.Code}) + + // Handling retained messages. + for _, pkv := range s.Topics.Messages(filter) { // [MQTT-3.8.4-4] + handler(s.inlineClient, inlineSubscription.Subscription, pkv) + } + return nil +} + +// Unsubscribe removes an inline subscription for the specified subscription and topic filter. +// It allows you to unsubscribe a specific subscription from the internal subscription +// associated with the given topic filter. +func (s *Server) Unsubscribe(filter string, subscriptionId int) error { + if !IsValidFilter(filter, false) { + return packets.ErrTopicFilterInvalid + } + + pk := s.hooks.OnUnsubscribe(s.inlineClient, packets.Packet{ + Origin: s.inlineClient.ID, + FixedHeader: packets.FixedHeader{Type: packets.Unsubscribe}, + Filters: packets.Subscriptions{ + { + Identifier: subscriptionId, + Filter: filter, + }, + }, + }) + + s.Topics.InlineUnsubscribe(subscriptionId, filter) + s.hooks.OnUnsubscribed(s.inlineClient, pk) + return nil +} + // InjectPacket injects a packet into the broker as if it were sent from the specified client. // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { @@ -736,7 +798,10 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { s.retainMessage(cl, pk) } - if pk.FixedHeader.Qos == 0 { + // If it's inlineClient, it can't handle PUBREC and PUBREL. + // When it publishes a package with a qos > 0, the server treats + // the package as qos=0, and the client receives it as qos=1 or 2. + if pk.FixedHeader.Qos == 0 || cl.Net.Inline { s.publishToSubscribers(pk) s.hooks.OnPublished(cl, pk) return nil @@ -809,6 +874,10 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { subscribers.MergeSharedSelected() } + for _, inlineSubscription := range subscribers.InlineSubscriptions { + inlineSubscription.Handler(s.inlineClient, inlineSubscription.Subscription, pk) + } + for id, subs := range subscribers.Subscriptions { if cl, ok := s.Clients.Get(id); ok { _, err := s.publishToClient(cl, subs, pk) diff --git a/server_test.go b/server_test.go index e9ea837..4930e8a 100644 --- a/server_test.go +++ b/server_test.go @@ -117,6 +117,8 @@ func TestNew(t *testing.T) { require.NotNil(t, s.hooks) require.NotNil(t, s.hooks.Log) require.NotNil(t, s.done) + require.NotNil(t, s.inlineClient) + require.Equal(t, 1, s.Clients.Len()) } func TestNewNilOpts(t *testing.T) { @@ -348,9 +350,13 @@ func TestEstablishConnection(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Todo: + // s.Clients is already empty here. Is it necessary to check v.StopCause()? + + // for _, v := range s.Clients.GetAll() { + // require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect + // } require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) @@ -409,9 +415,11 @@ func TestEstablishConnectionReadError(t *testing.T) { err := <-o require.Error(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect ret := <-recv require.Equal(t, append( @@ -467,9 +475,11 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect connackPlusPacket := append( packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, @@ -657,9 +667,11 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { err := <-o require.NoError(t, err) - for _, v := range s.Clients.GetAll() { - require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect - } + + // Retrieve the client corresponding to the Client Identifier. + retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) + require.True(t, ok) + require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) @@ -2973,9 +2985,11 @@ func TestServerLoadClients(t *testing.T) { } s := newServer() - require.Equal(t, 0, s.Clients.Len()) + // Here, server.Clients also includes a server.inlineClient, + // so the total number of clients here should be increased by 1 + require.Equal(t, 1, s.Clients.Len()) s.loadClients(v) - require.Equal(t, 3, s.Clients.Len()) + require.Equal(t, 4, s.Clients.Len()) cl, ok := s.Clients.Get("mochi") require.True(t, ok) require.Equal(t, "mochi", cl.ID) @@ -3003,7 +3017,10 @@ func TestServerLoadInflightMessages(t *testing.T) { {ID: "zen"}, {ID: "mochi-co"}, }) - require.Equal(t, 3, s.Clients.Len()) + + // server.Clients also includes a server.inlineClient, + // so the total number of clients here should be increased by 1 + require.Equal(t, 4, s.Clients.Len()) v := []storage.Message{ {Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"}, @@ -3159,11 +3176,15 @@ func TestServerClearExpiredClients(t *testing.T) { cl2.Properties.Props.SessionExpiryIntervalFlag = true s.Clients.Add(cl2) - require.Equal(t, 4, s.Clients.Len()) + // server.Clients also includes a server.inlineClient, + // so the total number of clients here should be increased by 1 + require.Equal(t, 5, s.Clients.Len()) s.clearExpiredClients(n) - require.Equal(t, 2, s.Clients.Len()) + // server.Clients also includes a server.inlineClient, + // so the total number of clients here should be increased by 1 + require.Equal(t, 3, s.Clients.Len()) } func TestLoadServerInfoRestoreOnRestart(t *testing.T) { @@ -3182,3 +3203,297 @@ func TestAtomicItoa(t *testing.T) { ip := &i require.Equal(t, "22", AtomicItoa(ip)) } + +func TestServerSubscribe(t *testing.T) { + + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + s := New(nil) + require.NotNil(t, s) + + tt := []struct { + desc string + filter string + identifier int + handler InlineSubFn + expect error + }{ + { + desc: "subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe", + filter: "a/b/c", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe d/e/f", + filter: "d/e/f", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "re-subscribe d/e/f by different identifier", + filter: "d/e/f", + identifier: 2, + handler: handler, + expect: nil, + }, + { + desc: "subscribe different handler", + filter: "a/b/c", + identifier: 1, + handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {}, + expect: nil, + }, + { + desc: "subscribe $SYS/info", + filter: "$SYS/info", + identifier: 1, + handler: handler, + expect: nil, + }, + { + desc: "subscribe invalied ###", + filter: "###", + identifier: 1, + handler: handler, + expect: packets.ErrTopicFilterInvalid, + }, + { + desc: "subscribe invalid handler", + filter: "a/b/c", + identifier: 1, + handler: nil, + expect: packets.ErrInlineSubscriptionHandlerInvalid, + }, + } + + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.expect, s.Subscribe(tx.filter, tx.identifier, tx.handler)) + }) + } +} + +func TestServerUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + s := New(nil) + err := s.Subscribe("a/b/c", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 1, handler) + require.Nil(t, err) + + err = s.Subscribe("d/e/f", 2, handler) + require.Nil(t, err) + + err = s.Unsubscribe("a/b/c", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 1) + require.Nil(t, err) + + err = s.Unsubscribe("d/e/f", 2) + require.Nil(t, err) + + err = s.Unsubscribe("not/exist", 1) + require.Nil(t, err) + + err = s.Unsubscribe("#/#/invalid", 1) + require.Equal(t, packets.ErrTopicFilterInvalid, err) +} + +func TestPublishToInlineSubscriber(t *testing.T) { + s := newServer() + finishCh := make(chan bool) + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.publishToSubscribers(pkx) + }() + + require.Equal(t, true, <-finishCh) +} + +func TestPublishToInlineSubscribersDiffrentFilter(t *testing.T) { + s := newServer() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.publishToSubscribers(pkx) + + pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet + s.publishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestPublishToInlineSubscribersDiffrentIdentifier(t *testing.T) { + s := newServer() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + s.publishToSubscribers(pkx) + }() + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetain(t *testing.T) { + s := newServer() + subNumber := 1 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + require.Equal(t, true, <-finishCh) +} + +func TestServerSubscribeWithRetainDiffrentFilter(t *testing.T) { + s := newServer() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + retained = s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("z/e/n", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("mochi mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "z/e/n", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} + +func TestServerSubscribeWithRetainDiffrentIdentifier(t *testing.T) { + s := newServer() + subNumber := 2 + finishCh := make(chan bool, subNumber) + + retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), retained) + + err := s.Subscribe("a/b/c", 1, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 1, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + err = s.Subscribe("a/b/c", 2, func(cl *Client, sub packets.Subscription, pk packets.Packet) { + require.Equal(t, []byte("hello mochi"), pk.Payload) + require.Equal(t, InlineClientId, cl.ID) + require.Equal(t, LocalListener, cl.Net.Listener) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, 2, sub.Identifier) + finishCh <- true + }) + require.Nil(t, err) + + for i := 0; i < subNumber; i++ { + require.Equal(t, true, <-finishCh) + } +} diff --git a/topics.go b/topics.go index be7c9a3..fe8b378 100644 --- a/topics.go +++ b/topics.go @@ -186,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio return m } +// InlineSubFn is the signature for a callback function which will be called +// when an inline client receives a message on a topic it is subscribed to. +// The sub argument contains information about the subscription that was matched for any filters. +type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet) + +// InlineSubscriptions represents a map of internal subscriptions keyed on client. +type InlineSubscriptions struct { + internal map[int]InlineSubscription + sync.RWMutex +} + +// NewInlineSubscriptions returns a new instance of InlineSubscriptions. +func NewInlineSubscriptions() *InlineSubscriptions { + return &InlineSubscriptions{ + internal: map[int]InlineSubscription{}, + } +} + +// Add adds a new internal subscription for a client id. +func (s *InlineSubscriptions) Add(val InlineSubscription) { + s.Lock() + defer s.Unlock() + s.internal[val.Identifier] = val +} + +// GetAll returns all internal subscriptions. +func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription { + s.RLock() + defer s.RUnlock() + m := map[int]InlineSubscription{} + for k, v := range s.internal { + m[k] = v + } + return m +} + +// Get returns an internal subscription for a client id. +func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) { + s.RLock() + defer s.RUnlock() + val, ok = s.internal[id] + return val, ok +} + +// Len returns the number of internal subscriptions. +func (s *InlineSubscriptions) Len() int { + s.RLock() + defer s.RUnlock() + val := len(s.internal) + return val +} + +// Delete removes an internal subscription by the client id. +func (s *InlineSubscriptions) Delete(id int) { + s.Lock() + defer s.Unlock() + delete(s.internal, id) +} + // Subscriptions is a map of subscriptions keyed on client. type Subscriptions struct { internal map[string]packets.Subscription @@ -244,11 +303,17 @@ func (s *Subscriptions) Delete(id string) { // ClientSubscriptions is a map of aggregated subscriptions for a client. type ClientSubscriptions map[string]packets.Subscription +type InlineSubscription struct { + packets.Subscription + Handler InlineSubFn +} + // Subscribers contains the shared and non-shared subscribers matching a topic. type Subscribers struct { - Shared map[string]map[string]packets.Subscription - SharedSelected map[string]packets.Subscription - Subscriptions map[string]packets.Subscription + Shared map[string]map[string]packets.Subscription + SharedSelected map[string]packets.Subscription + Subscriptions map[string]packets.Subscription + InlineSubscriptions map[int]InlineSubscription } // SelectShared returns one subscriber for each shared subscription group. @@ -298,6 +363,39 @@ func NewTopicsIndex() *TopicsIndex { } } +// InlineSubscribe adds a new internal subscription for a topic filter, returning +// true if the subscription was new. +func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool { + x.root.Lock() + defer x.root.Unlock() + + var existed bool + n := x.set(subscription.Filter, 0) + _, existed = n.inlineSubscriptions.Get(subscription.Identifier) + n.inlineSubscriptions.Add(subscription) + + return !existed +} + +// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client, +// returning true if the subscription existed. +func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool { + x.root.Lock() + defer x.root.Unlock() + + particle := x.seek(filter, 0) + if particle == nil { + return false + } + + particle.inlineSubscriptions.Delete(id) + + if particle.inlineSubscriptions.Len() == 0 { + x.trim(particle) + } + return true +} + // Subscribe adds a new subscription for a client to a topic filter, returning // true if the subscription was new. func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool { @@ -484,9 +582,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack // their subscription ids and highest qos. func (x *TopicsIndex) Subscribers(topic string) *Subscribers { return x.scanSubscribers(topic, 0, nil, &Subscribers{ - Shared: map[string]map[string]packets.Subscription{}, - SharedSelected: map[string]packets.Subscription{}, - Subscriptions: map[string]packets.Subscription{}, + Shared: map[string]map[string]packets.Subscription{}, + SharedSelected: map[string]packets.Subscription{}, + Subscriptions: map[string]packets.Subscription{}, + InlineSubscriptions: map[int]InlineSubscription{}, }) } @@ -508,10 +607,12 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su } else { x.gatherSubscriptions(topic, particle, subs) x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) if wild := particle.particles.get("#"); wild != nil && partKey != "+" { x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 x.gatherSharedSubscriptions(wild, subs) + x.gatherInlineSubscriptions(particle, subs) } } } @@ -520,6 +621,7 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su if particle := n.particles.get("#"); particle != nil { x.gatherSubscriptions(topic, particle, subs) x.gatherSharedSubscriptions(particle, subs) + x.gatherInlineSubscriptions(particle, subs) } return subs @@ -562,6 +664,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr } } +// gatherSharedSubscriptions gathers all inline subscriptions for a particle. +func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) { + if subs.InlineSubscriptions == nil { + subs.InlineSubscriptions = map[int]InlineSubscription{} + } + + for id, inline := range particle.inlineSubscriptions.GetAll() { + subs.InlineSubscriptions[id] = inline + } +} + // isolateParticle extracts a particle between d / and d+1 / without allocations. func isolateParticle(filter string, d int) (particle string, hasNext bool) { var next, end int @@ -633,23 +746,25 @@ func IsValidFilter(filter string, forPublish bool) bool { // particle is a child node on the tree. type particle struct { - key string // the key of the particle - parent *particle // a pointer to the parent of the particle - particles particles // a map of child particles - subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address - shared *SharedSubscriptions // a map of shared subscriptions keyed on group name - retainPath string // path of a retained message - sync.Mutex // mutex for when making changes to the particle + key string // the key of the particle + parent *particle // a pointer to the parent of the particle + particles particles // a map of child particles + subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address + shared *SharedSubscriptions // a map of shared subscriptions keyed on group name + inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle + retainPath string // path of a retained message + sync.Mutex // mutex for when making changes to the particle } // newParticle returns a pointer to a new instance of particle. func newParticle(key string, parent *particle) *particle { return &particle{ - key: key, - parent: parent, - particles: newParticles(), - subscriptions: NewSubscriptions(), - shared: NewSharedSubscriptions(), + key: key, + parent: parent, + particles: newParticles(), + subscriptions: NewSubscriptions(), + shared: NewSharedSubscriptions(), + inlineSubscriptions: NewInlineSubscriptions(), } } diff --git a/topics_test.go b/topics_test.go index b6f7551..fa6528e 100644 --- a/topics_test.go +++ b/topics_test.go @@ -5,6 +5,7 @@ package mqtt import ( + "fmt" "testing" "github.com/mochi-mqtt/server/v2/packets" @@ -853,3 +854,215 @@ func TestNewTopicAliases(t *testing.T) { require.NotNil(t, a.Outbound) require.Equal(t, uint16(5), a.Outbound.maximum) } + +func TestNewInlineSubscriptions(t *testing.T) { + subscriptions := NewInlineSubscriptions() + require.NotNil(t, subscriptions) + require.NotNil(t, subscriptions.internal) + require.Equal(t, 0, subscriptions.Len()) +} + +func TestInlineSubscriptionAdd(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) +} + +func TestInlineSubscriptionGet(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + sub, ok := subscriptions.Get(1) + require.True(t, ok) + require.Equal(t, "a/b/c", sub.Filter) + require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler)) + + _, ok = subscriptions.Get(999) + require.False(t, ok) +} + +func TestInlineSubscriptionsGetAll(t *testing.T) { + subscriptions := NewInlineSubscriptions() + + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}, + }) + subscriptions.Add(InlineSubscription{ + Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3}, + }) + + allSubs := subscriptions.GetAll() + require.Len(t, allSubs, 3) + require.Contains(t, allSubs, 1) + require.Contains(t, allSubs, 2) + require.Contains(t, allSubs, 3) +} + +func TestInlineSubscriptionDelete(t *testing.T) { + subscriptions := NewInlineSubscriptions() + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + subscription := InlineSubscription{ + Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}, + Handler: handler, + } + subscriptions.Add(subscription) + + subscriptions.Delete(1) + _, ok := subscriptions.Get(1) + require.False(t, ok) + require.Empty(t, subscriptions.GetAll()) + require.Zero(t, subscriptions.Len()) +} + +func TestInlineSubscribe(t *testing.T) { + + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + tt := []struct { + desc string + filter string + subscription InlineSubscription + wasNew bool + }{ + { + desc: "subscribe", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: true, + }, + { + desc: "subscribe existed", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}}, + wasNew: false, + }, + { + desc: "subscribe different identifier", + filter: "a/b/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}}, + wasNew: true, + }, + { + desc: "subscribe case sensitive didnt exist", + filter: "A/B/c", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}}, + wasNew: true, + }, + { + desc: "wildcard+ sub", + filter: "d/+", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}}, + wasNew: true, + }, + { + desc: "wildcard# sub", + filter: "d/e/#", + subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}}, + wasNew: true, + }, + } + + index := NewTopicsIndex() + for _, tx := range tt { + t.Run(tx.desc, func(t *testing.T) { + require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription)) + }) + } + + final := index.root.particles.get("a").particles.get("b").particles.get("c") + require.NotNil(t, final) +} + +func TestInlineUnsubscribe(t *testing.T) { + handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) { + // handler logic + } + + index := NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index = NewTopicsIndex() + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}}) + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}}) + sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}}) + sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok := index.InlineUnsubscribe(1, "a/b/c/d") + require.True(t, ok) + require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c")) + + sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1) + require.NotNil(t, sub) + require.True(t, exists) + + ok = index.InlineUnsubscribe(1, "d/e/f") + require.True(t, ok) + require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f")) + + ok = index.InlineUnsubscribe(1, "not/exist") + require.False(t, ok) +}