From b4332150f83ab853b309a2629c753217a16fd43d Mon Sep 17 00:00:00 2001 From: thedevop <60499013+thedevop@users.noreply.github.com> Date: Sat, 1 Mar 2025 05:50:37 -0800 Subject: [PATCH] Improve message expiry (#460) --- clients.go | 6 +++++- server.go | 30 +++++++++++++++++++++++++++--- server_test.go | 11 +++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/clients.go b/clients.go index 0aa400c..4921285 100644 --- a/clients.go +++ b/clients.go @@ -533,7 +533,11 @@ func (cl *Client) WritePacket(pk packets.Packet) error { } if pk.Expiry > 0 { - pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6] + expiry := pk.Expiry - time.Now().Unix() + if expiry < 1 { + expiry = 1 + } + pk.Properties.MessageExpiryInterval = uint32(expiry) // [MQTT-3.3.2-6] } pk.ProtocolVersion = cl.Properties.ProtocolVersion diff --git a/server.go b/server.go index 7153ca1..9181167 100644 --- a/server.go +++ b/server.go @@ -885,6 +885,11 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { pk.Origin = cl.ID pk.Created = time.Now().Unix() + if expiry := minimum(s.Options.Capabilities.MaximumMessageExpiryInterval, + int64(pk.Properties.MessageExpiryInterval)); expiry > 0 { + pk.Expiry = pk.Created + expiry + } + if !cl.Net.Inline { if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] @@ -986,9 +991,11 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { pk.Created = time.Now().Unix() } - pk.Expiry = pk.Created + s.Options.Capabilities.MaximumMessageExpiryInterval - if pk.Properties.MessageExpiryInterval > 0 { - pk.Expiry = pk.Created + int64(pk.Properties.MessageExpiryInterval) + if pk.Expiry == 0 { + if expiry := minimum(s.Options.Capabilities.MaximumMessageExpiryInterval, + int64(pk.Properties.MessageExpiryInterval)); expiry > 0 { + pk.Expiry = pk.Created + expiry + } } subscribers := s.Topics.Subscribers(pk.TopicName) @@ -1755,3 +1762,20 @@ func (s *Server) sendDelayedLWT(dt int64) { func Int64toa(v int64) string { return strconv.FormatInt(v, 10) } + +// minimum differs from built-in min, it returns minimum of the non-zero value a and b. +// If both a and b are zero value, it reutrns 0. +func minimum(a, b int64) (m int64) { + if a != 0 { + m = a + if b != 0 && b < a { + m = b + } + return + } + + if b != 0 { + m = b + } + return +} diff --git a/server_test.go b/server_test.go index 35e5139..0297f80 100644 --- a/server_test.go +++ b/server_test.go @@ -3920,3 +3920,14 @@ func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) { require.Equal(t, true, <-finishCh) } } + +func TestMinimum(t *testing.T) { + require.EqualValues(t, 0, minimum(0, 0)) + require.EqualValues(t, 1, minimum(0, 1)) + require.EqualValues(t, 1, minimum(1, 0)) + require.EqualValues(t, 10, minimum(10, 20)) + require.EqualValues(t, 20, minimum(30, 20)) + require.EqualValues(t, -1, minimum(-1, 0)) // negative values are not used, but included here for completeness + require.EqualValues(t, -1, minimum(-1, 20)) + require.EqualValues(t, -2, minimum(-1, -2)) +}