mirror of
				https://github.com/mochi-mqtt/server.git
				synced 2025-10-31 11:36:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			843 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			843 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // SPDX-License-Identifier: MIT
 | |
| // SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
 | |
| // SPDX-FileContributor: mochi-co
 | |
| 
 | |
| package mqtt
 | |
| 
 | |
| import (
 | |
| 	"testing"
 | |
| 
 | |
| 	"github.com/mochi-co/mqtt/v2/packets"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| )
 | |
| 
 | |
// testGroup is the primary share group name used throughout these tests.
const testGroup = "testgroup"

// otherGroup is a second share group name, used where a test needs two groups.
const otherGroup = "other"
 | |
| 
 | |
| func TestNewSharedSubscriptions(t *testing.T) {
 | |
| 	s := NewSharedSubscriptions()
 | |
| 	require.NotNil(t, s.internal)
 | |
| }
 | |
| 
 | |
| func TestSharedSubscriptionsAdd(t *testing.T) {
 | |
| 	s := NewSharedSubscriptions()
 | |
| 	s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"})
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl1")
 | |
| }
 | |
| 
 | |
| func TestSharedSubscriptionsGet(t *testing.T) {
 | |
| 	s := NewSharedSubscriptions()
 | |
| 	s.Add(testGroup, "cl1", packets.Subscription{Qos: 2})
 | |
| 	s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl1")
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl2")
 | |
| 
 | |
| 	sub, ok := s.Get(testGroup, "cl2")
 | |
| 	require.Equal(t, true, ok)
 | |
| 	require.Equal(t, byte(2), sub.Qos)
 | |
| }
 | |
| 
 | |
| func TestSharedSubscriptionsGetAll(t *testing.T) {
 | |
| 	s := NewSharedSubscriptions()
 | |
| 	s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
 | |
| 	s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
 | |
| 	s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2})
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl1")
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl2")
 | |
| 	require.Contains(t, s.internal, otherGroup)
 | |
| 	require.Contains(t, s.internal[otherGroup], "cl3")
 | |
| 
 | |
| 	subs := s.GetAll()
 | |
| 	require.Len(t, subs, 2)
 | |
| 	require.Len(t, subs[testGroup], 2)
 | |
| 	require.Len(t, subs[otherGroup], 1)
 | |
| }
 | |
| 
 | |
| func TestSharedSubscriptionsLen(t *testing.T) {
 | |
| 	s := NewSharedSubscriptions()
 | |
| 	s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
 | |
| 	s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
 | |
| 	s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1})
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl1")
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl2")
 | |
| 	require.Contains(t, s.internal, otherGroup)
 | |
| 	require.Contains(t, s.internal[otherGroup], "cl2")
 | |
| 	require.Equal(t, 3, s.Len())
 | |
| 	require.Equal(t, 2, s.GroupLen())
 | |
| }
 | |
| 
 | |
| func TestSharedSubscriptionsDelete(t *testing.T) {
 | |
| 	s := NewSharedSubscriptions()
 | |
| 	s.Add(testGroup, "cl1", packets.Subscription{Qos: 1})
 | |
| 	s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl1")
 | |
| 	require.Contains(t, s.internal, testGroup)
 | |
| 	require.Contains(t, s.internal[testGroup], "cl2")
 | |
| 
 | |
| 	require.Equal(t, 2, s.Len())
 | |
| 
 | |
| 	s.Delete(testGroup, "cl1")
 | |
| 	_, ok := s.Get(testGroup, "cl1")
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, 1, s.GroupLen())
 | |
| 	require.Equal(t, 1, s.Len())
 | |
| 
 | |
| 	s.Delete(testGroup, "cl2")
 | |
| 	_, ok = s.Get(testGroup, "cl2")
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, 0, s.GroupLen())
 | |
| 	require.Equal(t, 0, s.Len())
 | |
| }
 | |
| 
 | |
| func TestNewSubscriptions(t *testing.T) {
 | |
| 	s := NewSubscriptions()
 | |
| 	require.NotNil(t, s.internal)
 | |
| }
 | |
| 
 | |
| func TestSubscriptionsAdd(t *testing.T) {
 | |
| 	s := NewSubscriptions()
 | |
| 	s.Add("cl1", packets.Subscription{})
 | |
| 	require.Contains(t, s.internal, "cl1")
 | |
| }
 | |
| 
 | |
| func TestSubscriptionsGet(t *testing.T) {
 | |
| 	s := NewSubscriptions()
 | |
| 	s.Add("cl1", packets.Subscription{Qos: 2})
 | |
| 	s.Add("cl2", packets.Subscription{Qos: 2})
 | |
| 	require.Contains(t, s.internal, "cl1")
 | |
| 	require.Contains(t, s.internal, "cl2")
 | |
| 
 | |
| 	sub, ok := s.Get("cl1")
 | |
| 	require.True(t, ok)
 | |
| 	require.Equal(t, byte(2), sub.Qos)
 | |
| }
 | |
| 
 | |
| func TestSubscriptionsGetAll(t *testing.T) {
 | |
| 	s := NewSubscriptions()
 | |
| 	s.Add("cl1", packets.Subscription{Qos: 0})
 | |
| 	s.Add("cl2", packets.Subscription{Qos: 1})
 | |
| 	s.Add("cl3", packets.Subscription{Qos: 2})
 | |
| 	require.Contains(t, s.internal, "cl1")
 | |
| 	require.Contains(t, s.internal, "cl2")
 | |
| 	require.Contains(t, s.internal, "cl3")
 | |
| 
 | |
| 	subs := s.GetAll()
 | |
| 	require.Len(t, subs, 3)
 | |
| }
 | |
| 
 | |
| func TestSubscriptionsLen(t *testing.T) {
 | |
| 	s := NewSubscriptions()
 | |
| 	s.Add("cl1", packets.Subscription{Qos: 0})
 | |
| 	s.Add("cl2", packets.Subscription{Qos: 1})
 | |
| 	require.Contains(t, s.internal, "cl1")
 | |
| 	require.Contains(t, s.internal, "cl2")
 | |
| 	require.Equal(t, 2, s.Len())
 | |
| }
 | |
| 
 | |
| func TestSubscriptionsDelete(t *testing.T) {
 | |
| 	s := NewSubscriptions()
 | |
| 	s.Add("cl1", packets.Subscription{Qos: 1})
 | |
| 	require.Contains(t, s.internal, "cl1")
 | |
| 
 | |
| 	s.Delete("cl1")
 | |
| 	_, ok := s.Get("cl1")
 | |
| 	require.False(t, ok)
 | |
| }
 | |
| 
 | |
| func TestNewTopicsIndex(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	require.NotNil(t, index)
 | |
| 	require.NotNil(t, index.root)
 | |
| }
 | |
| 
 | |
| func BenchmarkNewTopicsIndex(b *testing.B) {
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		NewTopicsIndex()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestSubscribe(t *testing.T) {
 | |
| 	tt := []struct {
 | |
| 		desc         string
 | |
| 		client       string
 | |
| 		filter       string
 | |
| 		subscription packets.Subscription
 | |
| 		wasNew       bool
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:   "subscribe",
 | |
| 			client: "cl1",
 | |
| 
 | |
| 			subscription: packets.Subscription{Filter: "a/b/c", Qos: 2},
 | |
| 			wasNew:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:   "subscribe existed",
 | |
| 			client: "cl1",
 | |
| 
 | |
| 			subscription: packets.Subscription{Filter: "a/b/c", Qos: 1},
 | |
| 			wasNew:       false,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:   "subscribe case sensitive didnt exist",
 | |
| 			client: "cl1",
 | |
| 
 | |
| 			subscription: packets.Subscription{Filter: "A/B/c", Qos: 1},
 | |
| 			wasNew:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:   "wildcard+ sub",
 | |
| 			client: "cl1",
 | |
| 
 | |
| 			subscription: packets.Subscription{Filter: "d/+"},
 | |
| 			wasNew:       true,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:         "wildcard# sub",
 | |
| 			client:       "cl1",
 | |
| 			subscription: packets.Subscription{Filter: "d/e/#"},
 | |
| 			wasNew:       true,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	index := NewTopicsIndex()
 | |
| 	for _, tx := range tt {
 | |
| 		t.Run(tx.desc, func(t *testing.T) {
 | |
| 			require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription))
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	final := index.root.particles.get("a").particles.get("b").particles.get("c")
 | |
| 	require.NotNil(t, final)
 | |
| 	client, exists := final.subscriptions.Get("cl1")
 | |
| 	require.True(t, exists)
 | |
| 	require.Equal(t, byte(1), client.Qos)
 | |
| }
 | |
| 
 | |
| func TestSubscribeShared(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2})
 | |
| 	final := index.root.particles.get("a").particles.get("b").particles.get("c")
 | |
| 	require.NotNil(t, final)
 | |
| 	client, exists := final.shared.Get("tmp", "cl1")
 | |
| 	require.True(t, exists)
 | |
| 	require.Equal(t, byte(2), client.Qos)
 | |
| 	require.Equal(t, 0, final.subscriptions.Len())
 | |
| 	require.Equal(t, 1, final.shared.Len())
 | |
| }
 | |
| 
 | |
| func BenchmarkSubscribe(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func BenchmarkSubscribeShared(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestUnsubscribe(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1})
 | |
| 	client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1})
 | |
| 	client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1})
 | |
| 	client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1})
 | |
| 	client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2})
 | |
| 	client, exists = index.root.particles.get("#").subscriptions.Get("cl3")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	ok := index.Unsubscribe("a/b/c/d", "cl1")
 | |
| 	require.True(t, ok)
 | |
| 	require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
 | |
| 	client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	ok = index.Unsubscribe("d/e/f", "cl1")
 | |
| 	require.True(t, ok)
 | |
| 
 | |
| 	require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len())
 | |
| 	client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| 
 | |
| 	ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody")
 | |
| 	require.False(t, ok)
 | |
| }
 | |
| 
 | |
| func TestUnsubscribeNoCascade(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"})
 | |
| 
 | |
| 	ok := index.Unsubscribe("a/b/c/e/e", "cl1")
 | |
| 	require.True(t, ok)
 | |
| 	require.Equal(t, 1, index.root.particles.len())
 | |
| 
 | |
| 	client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1")
 | |
| 	require.NotNil(t, client)
 | |
| 	require.True(t, exists)
 | |
| }
 | |
| 
 | |
| func TestUnsubscribeShared(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2})
 | |
| 	final := index.root.particles.get("a").particles.get("b").particles.get("c")
 | |
| 	require.NotNil(t, final)
 | |
| 	client, exists := final.shared.Get("tmp", "cl1")
 | |
| 	require.True(t, exists)
 | |
| 	require.Equal(t, byte(2), client.Qos)
 | |
| 
 | |
| 	require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
 | |
| 	_, exists = final.shared.Get("tmp", "cl1")
 | |
| 	require.False(t, exists)
 | |
| }
 | |
| 
 | |
| func BenchmarkUnsubscribe(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		b.StopTimer()
 | |
| 		index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
 | |
| 		b.StartTimer()
 | |
| 		index.Unsubscribe("a/b/c", "cl1")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestIndexSeek(t *testing.T) {
 | |
| 	filter := "a/b/c/d/e/f"
 | |
| 	index := NewTopicsIndex()
 | |
| 	k1 := index.set(filter, 0)
 | |
| 	require.Equal(t, "f", k1.key)
 | |
| 	k1.subscriptions.Add("cl1", packets.Subscription{})
 | |
| 
 | |
| 	require.Equal(t, k1, index.seek(filter, 0))
 | |
| 	require.Nil(t, index.seek("d/e/f", 0))
 | |
| }
 | |
| 
 | |
| func TestIndexTrim(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	k1 := index.set("a/b/c", 0)
 | |
| 	require.Equal(t, "c", k1.key)
 | |
| 	k1.subscriptions.Add("cl1", packets.Subscription{})
 | |
| 
 | |
| 	k2 := index.set("a/b/c/d/e/f", 0)
 | |
| 	require.Equal(t, "f", k2.key)
 | |
| 	k2.subscriptions.Add("cl1", packets.Subscription{})
 | |
| 
 | |
| 	k3 := index.set("a/b", 0)
 | |
| 	require.Equal(t, "b", k3.key)
 | |
| 	k3.subscriptions.Add("cl1", packets.Subscription{})
 | |
| 
 | |
| 	index.trim(k2)
 | |
| 	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
 | |
| 	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f"))
 | |
| 	require.NotNil(t, index.root.particles.get("a").particles.get("b"))
 | |
| 
 | |
| 	k2.subscriptions.Delete("cl1")
 | |
| 	index.trim(k2)
 | |
| 
 | |
| 	require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d"))
 | |
| 	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
 | |
| 
 | |
| 	k1.subscriptions.Delete("cl1")
 | |
| 	k3.subscriptions.Delete("cl1")
 | |
| 	index.trim(k2)
 | |
| 	require.Nil(t, index.root.particles.get("a"))
 | |
| }
 | |
| 
 | |
| func TestIndexSet(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	child := index.set("a/b/c", 0)
 | |
| 	require.Equal(t, "c", child.key)
 | |
| 	require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
 | |
| 
 | |
| 	child = index.set("a/b/c/d/e", 0)
 | |
| 	require.Equal(t, "e", child.key)
 | |
| 
 | |
| 	child = index.set("a/b/c/c/a", 0)
 | |
| 	require.Equal(t, "a", child.key)
 | |
| }
 | |
| 
 | |
| func TestIndexSetPrefixed(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	child := index.set("/c", 0)
 | |
| 	require.Equal(t, "c", child.key)
 | |
| 	require.NotNil(t, index.root.particles.get("").particles.get("c"))
 | |
| }
 | |
| 
 | |
| func BenchmarkIndexSet(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		index.set("a/b/c", 0)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestRetainMessage(t *testing.T) {
 | |
| 	pk := packets.Packet{
 | |
| 		FixedHeader: packets.FixedHeader{Retain: true},
 | |
| 		TopicName:   "a/b/c",
 | |
| 		Payload:     []byte("hello"),
 | |
| 	}
 | |
| 
 | |
| 	index := NewTopicsIndex()
 | |
| 	r := index.RetainMessage(pk)
 | |
| 	require.Equal(t, int64(1), r)
 | |
| 	pke, ok := index.Retained.Get(pk.TopicName)
 | |
| 	require.True(t, ok)
 | |
| 	require.Equal(t, pk, pke)
 | |
| 
 | |
| 	pk2 := packets.Packet{
 | |
| 		FixedHeader: packets.FixedHeader{Retain: true},
 | |
| 		TopicName:   "a/b/d/f",
 | |
| 		Payload:     []byte("hello"),
 | |
| 	}
 | |
| 	r = index.RetainMessage(pk2)
 | |
| 	require.Equal(t, int64(1), r)
 | |
| 	// The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message.
 | |
| 	r = index.RetainMessage(pk2)
 | |
| 	require.Equal(t, int64(1), r)
 | |
| 
 | |
| 	// Clear existing retained
 | |
| 	pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}}
 | |
| 	r = index.RetainMessage(pk3)
 | |
| 	require.Equal(t, int64(-1), r)
 | |
| 	_, ok = index.Retained.Get(pk.TopicName)
 | |
| 	require.False(t, ok)
 | |
| 
 | |
| 	// Clear no retained
 | |
| 	r = index.RetainMessage(pk3)
 | |
| 	require.Equal(t, int64(0), r)
 | |
| }
 | |
| 
 | |
| func BenchmarkRetainMessage(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestIsolateParticle(t *testing.T) {
 | |
| 	particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
 | |
| 	require.Equal(t, "path", particle)
 | |
| 	require.Equal(t, true, hasNext)
 | |
| 	particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
 | |
| 	require.Equal(t, "to", particle)
 | |
| 	require.Equal(t, true, hasNext)
 | |
| 	particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
 | |
| 	require.Equal(t, "my", particle)
 | |
| 	require.Equal(t, true, hasNext)
 | |
| 	particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
 | |
| 	require.Equal(t, "mqtt", particle)
 | |
| 	require.Equal(t, false, hasNext)
 | |
| 
 | |
| 	particle, hasNext = isolateParticle("/path/", 0)
 | |
| 	require.Equal(t, "", particle)
 | |
| 	require.Equal(t, true, hasNext)
 | |
| 	particle, hasNext = isolateParticle("/path/", 1)
 | |
| 	require.Equal(t, "path", particle)
 | |
| 	require.Equal(t, true, hasNext)
 | |
| 	particle, hasNext = isolateParticle("/path/", 2)
 | |
| 	require.Equal(t, "", particle)
 | |
| 	require.Equal(t, false, hasNext)
 | |
| 
 | |
| 	particle, hasNext = isolateParticle("a/b/c/+/+", 3)
 | |
| 	require.Equal(t, "+", particle)
 | |
| 	require.Equal(t, true, hasNext)
 | |
| 	particle, hasNext = isolateParticle("a/b/c/+/+", 4)
 | |
| 	require.Equal(t, "+", particle)
 | |
| 	require.Equal(t, false, hasNext)
 | |
| }
 | |
| 
 | |
| func BenchmarkIsolateParticle(b *testing.B) {
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		isolateParticle("path/to/my/mqtt", 3)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestScanSubscribers(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22})
 | |
| 	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"})
 | |
| 	index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3})
 | |
| 	index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234})
 | |
| 	index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
 | |
| 
 | |
| 	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
 | |
| 	require.Equal(t, 4, len(subs.Subscriptions))
 | |
| 	require.Contains(t, subs.Subscriptions, "cl1")
 | |
| 	require.Contains(t, subs.Subscriptions, "cl2")
 | |
| 	require.Contains(t, subs.Subscriptions, "cl3")
 | |
| 	require.Contains(t, subs.Subscriptions, "cl4")
 | |
| 
 | |
| 	require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
 | |
| 	require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
 | |
| 	require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
 | |
| 	require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
 | |
| 
 | |
| 	require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
 | |
| 	require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
 | |
| 	require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
 | |
| 	require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
 | |
| 	require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
 | |
| 	require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
 | |
| 
 | |
| 	subs = index.scanSubscribers("", 0, nil, new(Subscribers))
 | |
| 	require.Equal(t, 0, len(subs.Subscriptions))
 | |
| }
 | |
| 
 | |
| func TestScanSubscribersShared(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
 | |
| 	index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
 | |
| 	index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
 | |
| 	index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
 | |
| 	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
 | |
| 	require.Equal(t, 3, len(subs.Shared))
 | |
| }
 | |
| 
 | |
| func TestSelectSharedSubscriber(t *testing.T) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110})
 | |
| 	index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
 | |
| 	index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
 | |
| 	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
 | |
| 	require.Equal(t, 2, len(subs.Shared))
 | |
| 	require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c")
 | |
| 	require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c")
 | |
| 	require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3)
 | |
| 	require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1)
 | |
| 	subs.SelectShared()
 | |
| 	require.Len(t, subs.SharedSelected, 2)
 | |
| }
 | |
| 
 | |
| func TestMergeSharedSelected(t *testing.T) {
 | |
| 	s := &Subscribers{
 | |
| 		SharedSelected: map[string]packets.Subscription{
 | |
| 			"cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110},
 | |
| 			"cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111},
 | |
| 		},
 | |
| 		Subscriptions: map[string]packets.Subscription{
 | |
| 			"cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	s.MergeSharedSelected()
 | |
| 
 | |
| 	require.Equal(t, 2, len(s.Subscriptions))
 | |
| 	require.Contains(t, s.Subscriptions, "cl1")
 | |
| 	require.Contains(t, s.Subscriptions, "cl2")
 | |
| 	require.EqualValues(t, map[string]int{
 | |
| 		SharePrefix + "/tmp2/a/b/c": 111,
 | |
| 		"a/b/c":                     112,
 | |
| 	}, s.Subscriptions["cl2"].Identifiers)
 | |
| }
 | |
| 
 | |
| func TestSubscribersFind(t *testing.T) {
 | |
| 	tt := []struct {
 | |
| 		filter  string
 | |
| 		topic   string
 | |
| 		matched bool
 | |
| 	}{
 | |
| 		{filter: "a", topic: "a", matched: true},
 | |
| 		{filter: "a/", topic: "a", matched: false},
 | |
| 		{filter: "a/", topic: "a/", matched: true},
 | |
| 		{filter: "/a", topic: "/a", matched: true},
 | |
| 		{filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "#", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2
 | |
| 		{filter: "trailing-end/#", topic: "trailing-end/", matched: true},
 | |
| 		{filter: "+/prefixed", topic: "/prefixed", matched: true},
 | |
| 		{filter: "+/+/#", topic: "path/to/my/mqtt", matched: true},
 | |
| 		{filter: "path/to/", topic: "path/to/my/mqtt", matched: false},
 | |
| 		{filter: "#/stuff", topic: "path/to/my/mqtt", matched: false},
 | |
| 		{filter: "#", topic: "$SYS/info", matched: false},
 | |
| 		{filter: "$SYS/#", topic: "$SYS/info", matched: true},
 | |
| 		{filter: "+/info", topic: "$SYS/info", matched: false},
 | |
| 	}
 | |
| 
 | |
| 	for _, tx := range tt {
 | |
| 		t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) {
 | |
| 			index := NewTopicsIndex()
 | |
| 			index.Subscribe("cl1", packets.Subscription{Filter: tx.filter})
 | |
| 			subs := index.Subscribers(tx.topic)
 | |
| 			require.Equal(t, tx.matched, len(subs.Subscriptions) == 1)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func BenchmarkSubscribers(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"})
 | |
| 	index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"})
 | |
| 	index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"})
 | |
| 	index.Subscribe("cl3", packets.Subscription{Filter: "#"})
 | |
| 
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		index.Subscribers("a/b/c")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestMessagesPattern(t *testing.T) {
 | |
| 	payload := []byte("hello")
 | |
| 	fh := packets.FixedHeader{Type: packets.Publish, Retain: true}
 | |
| 
 | |
| 	pks := []packets.Packet{
 | |
| 		{TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "$SYS/info", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh},
 | |
| 		{TopicName: "asdf", Payload: payload, FixedHeader: fh},
 | |
| 	}
 | |
| 
 | |
| 	tt := []struct {
 | |
| 		filter string
 | |
| 		len    int
 | |
| 	}{
 | |
| 		{"a/b/c/d", 1},
 | |
| 		{"$SYS/+", 2},
 | |
| 		{"$SYS/#", 2},
 | |
| 		{"#", len(pks) - 2},
 | |
| 		{"a/b/c/+", 2},
 | |
| 		{"a/+/c/+", 2},
 | |
| 		{"+/+/+/d", 1},
 | |
| 		{"q/w/e/#", 1},
 | |
| 		{"+/+/+/+", 3},
 | |
| 		{"q/#", 2},
 | |
| 		{"asdf", 1},
 | |
| 		{"", 0},
 | |
| 		{"#", 6},
 | |
| 	}
 | |
| 
 | |
| 	index := NewTopicsIndex()
 | |
| 	for _, pk := range pks {
 | |
| 		index.RetainMessage(pk)
 | |
| 	}
 | |
| 
 | |
| 	for _, tx := range tt {
 | |
| 		t.Run("filter:'"+tx.filter, func(t *testing.T) {
 | |
| 			messages := index.Messages(tx.filter)
 | |
| 			require.Equal(t, tx.len, len(messages))
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func BenchmarkMessages(b *testing.B) {
 | |
| 	index := NewTopicsIndex()
 | |
| 	index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
 | |
| 	index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"})
 | |
| 	index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"})
 | |
| 	index.RetainMessage(packets.Packet{TopicName: "$SYS/info"})
 | |
| 	index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
 | |
| 
 | |
| 	for n := 0; n < b.N; n++ {
 | |
| 		index.Messages("+/b/c/+")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestNewParticles(t *testing.T) {
 | |
| 	cl := newParticles()
 | |
| 	require.NotNil(t, cl.internal)
 | |
| }
 | |
| 
 | |
| func TestParticlesAdd(t *testing.T) {
 | |
| 	p := newParticles()
 | |
| 	p.add(&particle{key: "a"})
 | |
| 	require.Contains(t, p.internal, "a")
 | |
| }
 | |
| 
 | |
| func TestParticlesGet(t *testing.T) {
 | |
| 	p := newParticles()
 | |
| 	p.add(&particle{key: "a"})
 | |
| 	p.add(&particle{key: "b"})
 | |
| 	require.Contains(t, p.internal, "a")
 | |
| 	require.Contains(t, p.internal, "b")
 | |
| 
 | |
| 	particle := p.get("a")
 | |
| 	require.NotNil(t, particle)
 | |
| 	require.Equal(t, "a", particle.key)
 | |
| }
 | |
| 
 | |
| func TestParticlesGetAll(t *testing.T) {
 | |
| 	p := newParticles()
 | |
| 	p.add(&particle{key: "a"})
 | |
| 	p.add(&particle{key: "b"})
 | |
| 	p.add(&particle{key: "c"})
 | |
| 	require.Contains(t, p.internal, "a")
 | |
| 	require.Contains(t, p.internal, "b")
 | |
| 	require.Contains(t, p.internal, "c")
 | |
| 
 | |
| 	particles := p.getAll()
 | |
| 	require.Len(t, particles, 3)
 | |
| }
 | |
| 
 | |
| func TestParticlesLen(t *testing.T) {
 | |
| 	p := newParticles()
 | |
| 	p.add(&particle{key: "a"})
 | |
| 	p.add(&particle{key: "b"})
 | |
| 	require.Contains(t, p.internal, "a")
 | |
| 	require.Contains(t, p.internal, "b")
 | |
| 	require.Equal(t, 2, p.len())
 | |
| }
 | |
| 
 | |
| func TestParticlesDelete(t *testing.T) {
 | |
| 	p := newParticles()
 | |
| 	p.add(&particle{key: "a"})
 | |
| 	require.Contains(t, p.internal, "a")
 | |
| 
 | |
| 	p.delete("a")
 | |
| 	particle := p.get("a")
 | |
| 	require.Nil(t, particle)
 | |
| }
 | |
| 
 | |
| func TestIsValid(t *testing.T) {
 | |
| 	require.True(t, IsValidFilter("a/b/c", false))
 | |
| 	require.True(t, IsValidFilter("a/b//c", false))
 | |
| 	require.True(t, IsValidFilter("$SYS", false))
 | |
| 	require.True(t, IsValidFilter("$SYS/info", false))
 | |
| 	require.True(t, IsValidFilter("$sys/info", false))
 | |
| 	require.True(t, IsValidFilter("abc/#", false))
 | |
| 	require.False(t, IsValidFilter("", false))
 | |
| 	require.False(t, IsValidFilter(SharePrefix, false))
 | |
| 	require.False(t, IsValidFilter(SharePrefix+"/", false))
 | |
| 	require.False(t, IsValidFilter(SharePrefix+"/b+/", false))
 | |
| 	require.False(t, IsValidFilter(SharePrefix+"/+", false))
 | |
| 	require.False(t, IsValidFilter(SharePrefix+"/#", false))
 | |
| 	require.False(t, IsValidFilter(SharePrefix+"/#/", false))
 | |
| 	require.False(t, IsValidFilter("a/#/c", false))
 | |
| }
 | |
| 
 | |
| func TestIsValidForPublish(t *testing.T) {
 | |
| 	require.True(t, IsValidFilter("", true))
 | |
| 	require.True(t, IsValidFilter("a/b/c", true))
 | |
| 	require.False(t, IsValidFilter("a/b/+/d", true))
 | |
| 	require.False(t, IsValidFilter("a/b/#", true))
 | |
| 	require.False(t, IsValidFilter("$SYS/info", true))
 | |
| }
 | |
| 
 | |
| func TestIsSharedFilter(t *testing.T) {
 | |
| 	require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c"))
 | |
| 	require.False(t, IsSharedFilter("a/b/c"))
 | |
| }
 | |
| 
 | |
| func TestNewInboundAliases(t *testing.T) {
 | |
| 	a := NewInboundTopicAliases(5)
 | |
| 	require.NotNil(t, a)
 | |
| 	require.NotNil(t, a.internal)
 | |
| 	require.Equal(t, uint16(5), a.maximum)
 | |
| }
 | |
| 
 | |
| func TestInboundAliasesSet(t *testing.T) {
 | |
| 	topic := "test"
 | |
| 	id := uint16(1)
 | |
| 	a := NewInboundTopicAliases(5)
 | |
| 	require.Equal(t, topic, a.Set(id, topic))
 | |
| 	require.Contains(t, a.internal, id)
 | |
| 	require.Equal(t, a.internal[id], topic)
 | |
| 
 | |
| 	require.Equal(t, topic, a.Set(id, ""))
 | |
| }
 | |
| 
 | |
| func TestInboundAliasesSetMaxZero(t *testing.T) {
 | |
| 	topic := "test"
 | |
| 	id := uint16(1)
 | |
| 	a := NewInboundTopicAliases(0)
 | |
| 	require.Equal(t, topic, a.Set(id, topic))
 | |
| 	require.NotContains(t, a.internal, id)
 | |
| }
 | |
| 
 | |
| func TestNewOutboundAliases(t *testing.T) {
 | |
| 	a := NewOutboundTopicAliases(5)
 | |
| 	require.NotNil(t, a)
 | |
| 	require.NotNil(t, a.internal)
 | |
| 	require.Equal(t, uint16(5), a.maximum)
 | |
| 	require.Equal(t, uint32(0), a.cursor)
 | |
| }
 | |
| 
 | |
| func TestOutboundAliasesSet(t *testing.T) {
 | |
| 	a := NewOutboundTopicAliases(3)
 | |
| 	n, ok := a.Set("t1")
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, uint16(1), n)
 | |
| 
 | |
| 	n, ok = a.Set("t2")
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, uint16(2), n)
 | |
| 
 | |
| 	n, ok = a.Set("t3")
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, uint16(3), n)
 | |
| 
 | |
| 	n, ok = a.Set("t4")
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, uint16(0), n)
 | |
| 
 | |
| 	n, ok = a.Set("t2")
 | |
| 	require.True(t, ok)
 | |
| 	require.Equal(t, uint16(2), n)
 | |
| }
 | |
| 
 | |
| func TestOutboundAliasesSetMaxZero(t *testing.T) {
 | |
| 	topic := "test"
 | |
| 	a := NewOutboundTopicAliases(0)
 | |
| 	n, ok := a.Set(topic)
 | |
| 	require.False(t, ok)
 | |
| 	require.Equal(t, uint16(0), n)
 | |
| }
 | |
| 
 | |
| func TestNewTopicAliases(t *testing.T) {
 | |
| 	a := NewTopicAliases(5)
 | |
| 	require.NotNil(t, a.Inbound)
 | |
| 	require.Equal(t, uint16(5), a.Inbound.maximum)
 | |
| 	require.NotNil(t, a.Outbound)
 | |
| 	require.Equal(t, uint16(5), a.Outbound.maximum)
 | |
| }
 | 
