Correctly identify and clean taken-over sessions (#180)

This commit is contained in:
JB
2023-02-25 01:24:17 +00:00
committed by GitHub
parent a909d30923
commit 9b7a943888
6 changed files with 170 additions and 23 deletions

View File

@@ -142,6 +142,7 @@ type ClientState struct {
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
outboundQty int32 // number of messages currently in the outbound queue

View File

@@ -58,6 +58,18 @@ func (i *Inflight) Len() int {
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()

View File

@@ -61,6 +61,16 @@ func TestInflightLen(t *testing.T) {
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newTestClient()

View File

@@ -250,26 +250,26 @@ var TPacketData = map[byte]TPacketCases{
Desc: "mqtt v3.1.1",
Primary: true,
RawBytes: []byte{
Connect << 4, 16, // Fixed header
Connect << 4, 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 60, // Keepalive
0, 4, // Client ID - MSB+LSB
'z', 'e', 'n', '3', // Client ID "zen"
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 16,
Remaining: 15,
},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: false,
Keepalive: 60,
ClientIdentifier: "zen3",
ClientIdentifier: "zen",
},
},
},
@@ -426,9 +426,9 @@ var TPacketData = map[byte]TPacketCases{
Connect << 4, 28, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
194, // Packet Flags
0, 20, // Keepalive
4, // Protocol Version
0 | 1<<6 | 1<<7, // Packet Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 5, // Username MSB+LSB
@@ -444,7 +444,7 @@ var TPacketData = map[byte]TPacketCases{
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Clean: false,
Keepalive: 20,
ClientIdentifier: "zen",
UsernameFlag: true,

View File

@@ -26,7 +26,7 @@ import (
)
const (
Version = "2.2.3" // the current server version.
Version = "2.2.4" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
@@ -353,9 +353,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
if err != nil {
s.sendLWT(cl)
cl.Stop(err)
}
if err == nil {
} else {
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
}
@@ -365,9 +363,11 @@ func (s *Server) attachClient(cl *Client, listener string) error {
close(cl.State.outbound)
if expire {
s.UnsubscribeClient(cl)
cl.ClearInflights(math.MaxInt64, 0)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
s.UnsubscribeClient(cl)
if atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
}
}
return err
@@ -440,17 +440,17 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
return false // [MQTT-3.2.2-3]
}
atomic.StoreUint32(&existing.State.isTakenOver, 1)
if existing.State.Inflight.Len() > 0 {
existing.State.Inflight.Lock()
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
existing.State.Inflight.Unlock()
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
@@ -465,6 +465,15 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
cl.State.Subscriptions.Add(sub.Filter, sub)
}
// Clean the state of the existing client to prevent sequential take-overs
// from increasing memory usage by inflights + subs * client-id.
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
s.Log.Debug().Str("client", cl.ID).
Str("old_remote", existing.Net.Remote).
Str("new_remote", cl.Net.Remote).
Msg("session taken over")
return true // [MQTT-3.2.2-3]
}
@@ -1087,8 +1096,15 @@ func (s *Server) UnsubscribeClient(cl *Client) {
i := 0
filterMap := cl.State.Subscriptions.GetAll()
filters := make([]packets.Subscription, len(filterMap))
for k, v := range filterMap {
for k := range filterMap {
cl.State.Subscriptions.Delete(k)
}
if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
return
}
for k, v := range filterMap {
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.Info.Subscriptions, -1)
}

View File

@@ -48,6 +48,23 @@ func (h *AllowHook) Provides(b byte) bool {
func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true }
func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true }
type DelayHook struct {
HookBase
DisconnectDelay time.Duration
}
func (h *DelayHook) ID() string {
return "delay-hook"
}
func (h *DelayHook) Provides(b byte) bool {
return bytes.Contains([]byte{OnDisconnect}, []byte{b})
}
func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) {
time.Sleep(h.DisconnectDelay)
}
func newServer() *Server {
cc := *DefaultServerCapabilities
cc.MaximumMessageExpiryInterval = 0
@@ -401,6 +418,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
cl, r0, _ := newTestClient()
cl.Properties.ProtocolVersion = 5
cl.Properties.Username = []byte("mochi")
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
@@ -456,9 +474,99 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
require.True(t, ok)
require.NotEmpty(t, clw.State.Subscriptions)
sub, ok := cl.State.Subscriptions.Get("a/b/c")
// Prevent sequential takeover memory-bloom.
require.Empty(t, cl.State.Subscriptions.GetAll())
}
// See https://github.com/mochi-co/mqtt/issues/173
func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) {
s := newServer()
d := new(DelayHook)
d.DisconnectDelay = time.Millisecond * 200
s.AddHook(d, nil)
defer s.Close()
// Clean session, 0 session expiry interval
cl1RawBytes := []byte{
packets.Connect << 4, 21, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
1 << 1, // Packet Flags
0, 30, // Keepalive
5, // Properties length
17, 0, 0, 0, 0, // Session Expiry Interval (17)
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
}
// Make first connection
r1, w1 := net.Pipe()
o1 := make(chan error)
go func() {
err := s.EstablishConnection("tcp", r1)
o1 <- err
}()
go func() {
w1.Write(cl1RawBytes)
}()
// receive the first connack
recv := make(chan []byte)
go func() {
buf, err := io.ReadAll(w1)
require.NoError(t, err)
recv <- buf
}()
// Get the first client pointer
time.Sleep(time.Millisecond * 50)
cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.Equal(t, packets.Subscription{Filter: "a/b/c", Qos: 1}, sub)
cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0})
time.Sleep(time.Millisecond * 50)
// Make the second connection
r2, w2 := net.Pipe()
o2 := make(chan error)
go func() {
err := s.EstablishConnection("tcp", r2)
o2 <- err
}()
go func() {
x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:]
x[19] = '.' // differentiate username bytes in debugging
w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes)
}()
// receive the second connack
recv2 := make(chan []byte)
go func() {
buf, err := io.ReadAll(w2)
require.NoError(t, err)
recv2 <- buf
}()
// Capture first Client pointer
clp1, ok := s.Clients.Get("zen")
require.True(t, ok)
require.Empty(t, clp1.Properties.Username)
require.NotEmpty(t, clp1.State.Subscriptions.GetAll())
err1 := <-o1
require.Error(t, err1)
require.ErrorIs(t, err1, io.ErrClosedPipe)
// Capture second Client pointer
clp2, ok := s.Clients.Get("zen")
require.True(t, ok)
require.Equal(t, []byte(".ochi"), clp2.Properties.Username)
require.NotEmpty(t, clp2.State.Subscriptions.GetAll())
require.Empty(t, clp1.State.Subscriptions.GetAll())
w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
require.NoError(t, <-o2)
}
func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {