mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-26 20:21:12 +08:00
Correctly identify and clean taken-over sessions (#180)
This commit is contained in:
@@ -142,6 +142,7 @@ type ClientState struct {
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
outbound chan *packets.Packet // queue for pending outbound packets
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
done uint32 // atomic counter which indicates that the client has closed
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
|
12
inflight.go
12
inflight.go
@@ -58,6 +58,18 @@ func (i *Inflight) Len() int {
|
||||
return len(i.internal)
|
||||
}
|
||||
|
||||
// Clone returns a new instance of Inflight with the same message data.
|
||||
// This is used when transferring inflights from a taken-over session.
|
||||
func (i *Inflight) Clone() *Inflight {
|
||||
c := NewInflights()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
for k, v := range i.internal {
|
||||
c.internal[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetAll returns all the inflight messages.
|
||||
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
|
||||
i.RLock()
|
||||
|
@@ -61,6 +61,16 @@ func TestInflightLen(t *testing.T) {
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightClone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
cloned := cl.State.Inflight.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.NotSame(t, cloned, cl.State.Inflight)
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
|
@@ -250,26 +250,26 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "mqtt v3.1.1",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 16, // Fixed header
|
||||
Connect << 4, 15, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 60, // Keepalive
|
||||
0, 4, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', '3', // Client ID "zen"
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connect,
|
||||
Remaining: 16,
|
||||
Remaining: 15,
|
||||
},
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: false,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "zen3",
|
||||
ClientIdentifier: "zen",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -426,9 +426,9 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Connect << 4, 28, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
194, // Packet Flags
|
||||
0, 20, // Keepalive
|
||||
4, // Protocol Version
|
||||
0 | 1<<6 | 1<<7, // Packet Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 5, // Username MSB+LSB
|
||||
@@ -444,7 +444,7 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: true,
|
||||
Clean: false,
|
||||
Keepalive: 20,
|
||||
ClientIdentifier: "zen",
|
||||
UsernameFlag: true,
|
||||
|
40
server.go
40
server.go
@@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "2.2.3" // the current server version.
|
||||
Version = "2.2.4" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
)
|
||||
|
||||
@@ -353,9 +353,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
|
||||
if err != nil {
|
||||
s.sendLWT(cl)
|
||||
cl.Stop(err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
} else {
|
||||
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
|
||||
}
|
||||
|
||||
@@ -365,9 +363,11 @@ func (s *Server) attachClient(cl *Client, listener string) error {
|
||||
close(cl.State.outbound)
|
||||
|
||||
if expire {
|
||||
s.UnsubscribeClient(cl)
|
||||
cl.ClearInflights(math.MaxInt64, 0)
|
||||
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
||||
s.UnsubscribeClient(cl)
|
||||
if atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
|
||||
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -440,17 +440,17 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
|
||||
// session is abandoned.
|
||||
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
||||
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
|
||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
return false // [MQTT-3.2.2-3]
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&existing.State.isTakenOver, 1)
|
||||
|
||||
if existing.State.Inflight.Len() > 0 {
|
||||
existing.State.Inflight.Lock()
|
||||
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
|
||||
existing.State.Inflight.Unlock()
|
||||
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
|
||||
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.capabilities.ReceiveMaximum != 0 {
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
@@ -465,6 +465,15 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
||||
cl.State.Subscriptions.Add(sub.Filter, sub)
|
||||
}
|
||||
|
||||
// Clean the state of the existing client to prevent sequential take-overs
|
||||
// from increasing memory usage by inflights + subs * client-id.
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
s.Log.Debug().Str("client", cl.ID).
|
||||
Str("old_remote", existing.Net.Remote).
|
||||
Str("new_remote", cl.Net.Remote).
|
||||
Msg("session taken over")
|
||||
|
||||
return true // [MQTT-3.2.2-3]
|
||||
}
|
||||
|
||||
@@ -1087,8 +1096,15 @@ func (s *Server) UnsubscribeClient(cl *Client) {
|
||||
i := 0
|
||||
filterMap := cl.State.Subscriptions.GetAll()
|
||||
filters := make([]packets.Subscription, len(filterMap))
|
||||
for k, v := range filterMap {
|
||||
for k := range filterMap {
|
||||
cl.State.Subscriptions.Delete(k)
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range filterMap {
|
||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||
atomic.AddInt64(&s.Info.Subscriptions, -1)
|
||||
}
|
||||
|
112
server_test.go
112
server_test.go
@@ -48,6 +48,23 @@ func (h *AllowHook) Provides(b byte) bool {
|
||||
func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true }
|
||||
func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true }
|
||||
|
||||
type DelayHook struct {
|
||||
HookBase
|
||||
DisconnectDelay time.Duration
|
||||
}
|
||||
|
||||
func (h *DelayHook) ID() string {
|
||||
return "delay-hook"
|
||||
}
|
||||
|
||||
func (h *DelayHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{OnDisconnect}, []byte{b})
|
||||
}
|
||||
|
||||
func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
time.Sleep(h.DisconnectDelay)
|
||||
}
|
||||
|
||||
func newServer() *Server {
|
||||
cc := *DefaultServerCapabilities
|
||||
cc.MaximumMessageExpiryInterval = 0
|
||||
@@ -401,6 +418,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
|
||||
|
||||
cl, r0, _ := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.Properties.Username = []byte("mochi")
|
||||
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
|
||||
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
|
||||
cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
@@ -456,9 +474,99 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, clw.State.Subscriptions)
|
||||
|
||||
sub, ok := cl.State.Subscriptions.Get("a/b/c")
|
||||
// Prevent sequential takeover memory-bloom.
|
||||
require.Empty(t, cl.State.Subscriptions.GetAll())
|
||||
}
|
||||
|
||||
// See https://github.com/mochi-co/mqtt/issues/173
|
||||
func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) {
|
||||
s := newServer()
|
||||
d := new(DelayHook)
|
||||
d.DisconnectDelay = time.Millisecond * 200
|
||||
s.AddHook(d, nil)
|
||||
defer s.Close()
|
||||
|
||||
// Clean session, 0 session expiry interval
|
||||
cl1RawBytes := []byte{
|
||||
packets.Connect << 4, 21, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
5, // Protocol Version
|
||||
1 << 1, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
5, // Properties length
|
||||
17, 0, 0, 0, 0, // Session Expiry Interval (17)
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
}
|
||||
|
||||
// Make first connection
|
||||
r1, w1 := net.Pipe()
|
||||
o1 := make(chan error)
|
||||
go func() {
|
||||
err := s.EstablishConnection("tcp", r1)
|
||||
o1 <- err
|
||||
}()
|
||||
go func() {
|
||||
w1.Write(cl1RawBytes)
|
||||
}()
|
||||
|
||||
// receive the first connack
|
||||
recv := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(w1)
|
||||
require.NoError(t, err)
|
||||
recv <- buf
|
||||
}()
|
||||
|
||||
// Get the first client pointer
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Subscription{Filter: "a/b/c", Qos: 1}, sub)
|
||||
cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
|
||||
cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0})
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
|
||||
// Make the second connection
|
||||
r2, w2 := net.Pipe()
|
||||
o2 := make(chan error)
|
||||
go func() {
|
||||
err := s.EstablishConnection("tcp", r2)
|
||||
o2 <- err
|
||||
}()
|
||||
go func() {
|
||||
x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:]
|
||||
x[19] = '.' // differentiate username bytes in debugging
|
||||
w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes)
|
||||
}()
|
||||
|
||||
// receive the second connack
|
||||
recv2 := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(w2)
|
||||
require.NoError(t, err)
|
||||
recv2 <- buf
|
||||
}()
|
||||
|
||||
// Capture first Client pointer
|
||||
clp1, ok := s.Clients.Get("zen")
|
||||
require.True(t, ok)
|
||||
require.Empty(t, clp1.Properties.Username)
|
||||
require.NotEmpty(t, clp1.State.Subscriptions.GetAll())
|
||||
|
||||
err1 := <-o1
|
||||
require.Error(t, err1)
|
||||
require.ErrorIs(t, err1, io.ErrClosedPipe)
|
||||
|
||||
// Capture second Client pointer
|
||||
clp2, ok := s.Clients.Get("zen")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, []byte(".ochi"), clp2.Properties.Username)
|
||||
require.NotEmpty(t, clp2.State.Subscriptions.GetAll())
|
||||
require.Empty(t, clp1.State.Subscriptions.GetAll())
|
||||
|
||||
w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
|
||||
require.NoError(t, <-o2)
|
||||
}
|
||||
|
||||
func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
|
||||
|
Reference in New Issue
Block a user