Use context to signal client open state (#218)

This commit is contained in:
JB
2023-05-06 11:55:40 +01:00
committed by GitHub
parent 6704cf7227
commit a734a0dc73
4 changed files with 23 additions and 15 deletions

View File

@@ -7,6 +7,7 @@ package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -144,7 +145,7 @@ type ClientState struct {
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
open context.Context // indicate that the client is open for packet exchange
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
@@ -158,6 +159,7 @@ func newClient(c net.Conn, o *ops) *Client {
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
open: context.Background(),
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
@@ -330,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
@@ -371,7 +373,12 @@ func (cl *Client) Stop(err error) {
close(cl.State.outbound)
}
atomic.StoreUint32(&cl.State.done, 1)
if cl.State.open != nil {
var cancel context.CancelFunc
cl.State.open, cancel = context.WithCancel(cl.State.open)
cancel()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -386,7 +393,7 @@ func (cl *Client) StopCause() error {
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
@@ -548,7 +555,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
cl.Lock()
defer cl.Unlock()
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return ErrConnectionClosed
}

View File

@@ -5,6 +5,7 @@
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -114,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -466,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.open = nil
o := make(chan error)
go func() {
@@ -483,7 +484,7 @@ func TestClientStop(t *testing.T) {
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}

View File

@@ -847,7 +847,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
}
}
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Net.Conn == nil || cl.Closed() {
return out, packets.CodeDisconnect
}

View File

@@ -2476,7 +2476,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, s.loop.willDelayed.Len())
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
require.True(t, cl.Closed())
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
}
@@ -2806,7 +2806,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl0, _, _ := newTestClient()
cl0.ID = "c0"
cl0.State.disconnected = n - 10
cl0.State.done = 1
cl0.State.open = nil
cl0.Properties.ProtocolVersion = 5
cl0.Properties.Props.SessionExpiryInterval = 12
cl0.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2816,7 +2816,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl1, _, _ := newTestClient()
cl1.ID = "c1"
cl1.State.disconnected = n - 10
cl1.State.done = 1
cl1.State.open = nil
cl1.Properties.ProtocolVersion = 5
cl1.Properties.Props.SessionExpiryInterval = 8
cl1.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2826,7 +2826,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl2, _, _ := newTestClient()
cl2.ID = "c2"
cl2.State.disconnected = n - 10
cl2.State.done = 1
cl2.State.open = nil
cl2.Properties.ProtocolVersion = 5
cl2.Properties.Props.SessionExpiryInterval = 0
cl2.Properties.Props.SessionExpiryIntervalFlag = true