diff --git a/clients.go b/clients.go index dacff7e..0aa400c 100644 --- a/clients.go +++ b/clients.go @@ -150,7 +150,7 @@ type ClientState struct { disconnected int64 // the time the client disconnected in unix time, for calculating expiry outbound chan *packets.Packet // queue for pending outbound packets endOnce sync.Once // only end once - isTakenOver uint32 // used to identify orphaned clients + isTakenOver atomic.Bool // used to identify orphaned clients packetID uint32 // the current highest packetID open context.Context // indicate that the client is open for packet exchange cancelOpen context.CancelFunc // cancel function for open context @@ -427,6 +427,10 @@ func (cl *Client) Closed() bool { return cl.State.open == nil || cl.State.open.Err() != nil } +func (cl *Client) IsTakenOver() bool { + return cl.State.isTakenOver.Load() +} + // ReadFixedHeader reads in the values of the next packet's fixed header. func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { if cl.Net.bconn == nil { diff --git a/clients_test.go b/clients_test.go index c5508fc..af93390 100644 --- a/clients_test.go +++ b/clients_test.go @@ -599,6 +599,13 @@ func TestClientClosed(t *testing.T) { require.True(t, cl.Closed()) } +func TestClientIsTakenOver(t *testing.T) { + cl, _, _ := newTestClient() + require.False(t, cl.IsTakenOver()) + cl.State.isTakenOver.Store(true) + require.True(t, cl.IsTakenOver()) +} + func TestClientReadFixedHeaderError(t *testing.T) { cl, r, _ := newTestClient() defer cl.Stop(errClientStop) diff --git a/hooks.go b/hooks.go index 4da709f..28d78e8 100644 --- a/hooks.go +++ b/hooks.go @@ -405,6 +405,8 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er "hook", hook.ID(), "packet", pkx) return pk, err + } else if errors.Is(err, packets.CodeSuccessIgnore) { + return pk, err } h.Log.Error("publish packet error", "error", err, diff --git a/server.go b/server.go index 5bcd5cc..b1163a3 100644 --- a/server.go +++ b/server.go @@ -485,7 +485,7 @@ func (s *Server) attachClient(cl *Client, listener string) error { expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) s.hooks.OnDisconnect(cl, err, expire) - if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 { + if expire && !cl.IsTakenOver() { cl.ClearInflights() s.UnsubscribeClient(cl) s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] @@ -565,11 +565,11 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] s.UnsubscribeClient(existing) existing.ClearInflights() - atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred - return false // [MQTT-3.2.2-3] + existing.State.isTakenOver.Store(true) // only set isTakenOver after unsubscribe has occurred + return false // [MQTT-3.2.2-3] } - atomic.StoreUint32(&existing.State.isTakenOver, 1) + existing.State.isTakenOver.Store(true) if existing.State.Inflight.Len() > 0 { cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5] if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 { @@ -1358,7 +1358,7 @@ func (s *Server) UnsubscribeClient(cl *Client) { cl.State.Subscriptions.Delete(k) } - if atomic.LoadUint32(&cl.State.isTakenOver) == 1 { + if cl.IsTakenOver() { return } diff --git a/server_test.go b/server_test.go index ab0830a..35e5139 100644 --- a/server_test.go +++ b/server_test.go @@ -667,6 +667,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) require.NotEmpty(t, clw.State.Subscriptions) + require.True(t, cl.IsTakenOver()) // Prevent sequential takeover memory-bloom. require.Empty(t, cl.State.Subscriptions.GetAll()) @@ -761,6 +762,9 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { _, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) require.NoError(t, <-o2) + + require.True(t, clp1.IsTakenOver()) + require.False(t, clp2.IsTakenOver()) } func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { @@ -848,12 +852,15 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv) require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover) + require.True(t, cl.IsTakenOver()) + _ = w.Close() _ = r.Close() clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier) require.True(t, ok) require.Equal(t, 0, clw.State.Subscriptions.Len()) + } func TestEstablishConnectionBadAuthentication(t *testing.T) {