mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-26 20:21:12 +08:00
Use context to signal client open state (#218)
This commit is contained in:
19
clients.go
19
clients.go
@@ -7,6 +7,7 @@ package mqtt
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
defer cl.RUnlock()
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
for _, client := range cl.internal {
|
||||
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
|
||||
if client.Net.Listener == id && !client.Closed() {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
@@ -144,7 +145,7 @@ type ClientState struct {
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
done uint32 // atomic counter which indicates that the client has closed
|
||||
open context.Context // indicate that the client is open for packet exchange
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
}
|
||||
@@ -158,6 +159,7 @@ func newClient(c net.Conn, o *ops) *Client {
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||
keepalive: defaultKeepalive,
|
||||
open: context.Background(),
|
||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
@@ -330,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Closed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -371,7 +373,12 @@ func (cl *Client) Stop(err error) {
|
||||
close(cl.State.outbound)
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.done, 1)
|
||||
if cl.State.open != nil {
|
||||
var cancel context.CancelFunc
|
||||
cl.State.open, cancel = context.WithCancel(cl.State.open)
|
||||
cancel()
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||
})
|
||||
}
|
||||
@@ -386,7 +393,7 @@ func (cl *Client) StopCause() error {
|
||||
|
||||
// Closed returns true if client connection is closed.
|
||||
func (cl *Client) Closed() bool {
|
||||
return atomic.LoadUint32(&cl.State.done) == 1
|
||||
return cl.State.open == nil || cl.State.open.Err() != nil
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
@@ -548,7 +555,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Closed() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
|
@@ -5,6 +5,7 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@@ -114,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
|
||||
|
||||
func TestClientsGetByListener(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
|
||||
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
@@ -466,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.done = 1
|
||||
cl.State.open = nil
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
@@ -483,7 +484,7 @@ func TestClientStop(t *testing.T) {
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||
require.Equal(t, uint32(1), cl.State.done)
|
||||
require.True(t, cl.Closed())
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
|
@@ -847,7 +847,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
}
|
||||
}
|
||||
|
||||
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Net.Conn == nil || cl.Closed() {
|
||||
return out, packets.CodeDisconnect
|
||||
}
|
||||
|
||||
|
@@ -2476,7 +2476,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 0, s.loop.willDelayed.Len())
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
|
||||
require.True(t, cl.Closed())
|
||||
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
|
||||
}
|
||||
|
||||
@@ -2806,7 +2806,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
||||
cl0, _, _ := newTestClient()
|
||||
cl0.ID = "c0"
|
||||
cl0.State.disconnected = n - 10
|
||||
cl0.State.done = 1
|
||||
cl0.State.open = nil
|
||||
cl0.Properties.ProtocolVersion = 5
|
||||
cl0.Properties.Props.SessionExpiryInterval = 12
|
||||
cl0.Properties.Props.SessionExpiryIntervalFlag = true
|
||||
@@ -2816,7 +2816,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
||||
cl1, _, _ := newTestClient()
|
||||
cl1.ID = "c1"
|
||||
cl1.State.disconnected = n - 10
|
||||
cl1.State.done = 1
|
||||
cl1.State.open = nil
|
||||
cl1.Properties.ProtocolVersion = 5
|
||||
cl1.Properties.Props.SessionExpiryInterval = 8
|
||||
cl1.Properties.Props.SessionExpiryIntervalFlag = true
|
||||
@@ -2826,7 +2826,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
||||
cl2, _, _ := newTestClient()
|
||||
cl2.ID = "c2"
|
||||
cl2.State.disconnected = n - 10
|
||||
cl2.State.done = 1
|
||||
cl2.State.open = nil
|
||||
cl2.Properties.ProtocolVersion = 5
|
||||
cl2.Properties.Props.SessionExpiryInterval = 0
|
||||
cl2.Properties.Props.SessionExpiryIntervalFlag = true
|
||||
|
Reference in New Issue
Block a user