mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-27 20:42:19 +08:00
Compare commits
1 Commits
v2.6.7
...
use-contex
Author | SHA1 | Date | |
---|---|---|---|
![]() |
ef5dcf68d0 |
19
clients.go
19
clients.go
@@ -7,6 +7,7 @@ package mqtt
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
|||||||
defer cl.RUnlock()
|
defer cl.RUnlock()
|
||||||
clients := make([]*Client, 0, cl.Len())
|
clients := make([]*Client, 0, cl.Len())
|
||||||
for _, client := range cl.internal {
|
for _, client := range cl.internal {
|
||||||
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
|
if client.Net.Listener == id && !client.Closed() {
|
||||||
clients = append(clients, client)
|
clients = append(clients, client)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -144,7 +145,7 @@ type ClientState struct {
|
|||||||
endOnce sync.Once // only end once
|
endOnce sync.Once // only end once
|
||||||
isTakenOver uint32 // used to identify orphaned clients
|
isTakenOver uint32 // used to identify orphaned clients
|
||||||
packetID uint32 // the current highest packetID
|
packetID uint32 // the current highest packetID
|
||||||
done uint32 // atomic counter which indicates that the client has closed
|
open context.Context // indicate that the client is open for packet exchange
|
||||||
outboundQty int32 // number of messages currently in the outbound queue
|
outboundQty int32 // number of messages currently in the outbound queue
|
||||||
keepalive uint16 // the number of seconds the connection can wait
|
keepalive uint16 // the number of seconds the connection can wait
|
||||||
}
|
}
|
||||||
@@ -158,6 +159,7 @@ func newClient(c net.Conn, o *ops) *Client {
|
|||||||
Subscriptions: NewSubscriptions(),
|
Subscriptions: NewSubscriptions(),
|
||||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||||
keepalive: defaultKeepalive,
|
keepalive: defaultKeepalive,
|
||||||
|
open: context.Background(),
|
||||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||||
},
|
},
|
||||||
Properties: ClientProperties{
|
Properties: ClientProperties{
|
||||||
@@ -330,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
if cl.Closed() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -371,7 +373,12 @@ func (cl *Client) Stop(err error) {
|
|||||||
close(cl.State.outbound)
|
close(cl.State.outbound)
|
||||||
}
|
}
|
||||||
|
|
||||||
atomic.StoreUint32(&cl.State.done, 1)
|
if cl.State.open != nil {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
cl.State.open, cancel = context.WithCancel(cl.State.open)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -386,7 +393,7 @@ func (cl *Client) StopCause() error {
|
|||||||
|
|
||||||
// Closed returns true if client connection is closed.
|
// Closed returns true if client connection is closed.
|
||||||
func (cl *Client) Closed() bool {
|
func (cl *Client) Closed() bool {
|
||||||
return atomic.LoadUint32(&cl.State.done) == 1
|
return cl.State.open == nil || cl.State.open.Err() != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||||
@@ -548,7 +555,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
|||||||
cl.Lock()
|
cl.Lock()
|
||||||
defer cl.Unlock()
|
defer cl.Unlock()
|
||||||
|
|
||||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
if cl.Closed() {
|
||||||
return ErrConnectionClosed
|
return ErrConnectionClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -5,6 +5,7 @@
|
|||||||
package mqtt
|
package mqtt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -114,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
|
|||||||
|
|
||||||
func TestClientsGetByListener(t *testing.T) {
|
func TestClientsGetByListener(t *testing.T) {
|
||||||
cl := NewClients()
|
cl := NewClients()
|
||||||
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
|
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
|
||||||
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
|
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
|
||||||
require.Contains(t, cl.internal, "t1")
|
require.Contains(t, cl.internal, "t1")
|
||||||
require.Contains(t, cl.internal, "t2")
|
require.Contains(t, cl.internal, "t2")
|
||||||
|
|
||||||
@@ -466,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
|
|||||||
func TestClientReadDone(t *testing.T) {
|
func TestClientReadDone(t *testing.T) {
|
||||||
cl, _, _ := newTestClient()
|
cl, _, _ := newTestClient()
|
||||||
defer cl.Stop(errClientStop)
|
defer cl.Stop(errClientStop)
|
||||||
cl.State.done = 1
|
cl.State.open = nil
|
||||||
|
|
||||||
o := make(chan error)
|
o := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -483,7 +484,7 @@ func TestClientStop(t *testing.T) {
|
|||||||
cl.Stop(nil)
|
cl.Stop(nil)
|
||||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||||
require.Equal(t, uint32(1), cl.State.done)
|
require.True(t, cl.Closed())
|
||||||
require.Equal(t, nil, cl.StopCause())
|
require.Equal(t, nil, cl.StopCause())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -847,7 +847,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
|
if cl.Net.Conn == nil || cl.Closed() {
|
||||||
return out, packets.CodeDisconnect
|
return out, packets.CodeDisconnect
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2476,7 +2476,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, 0, s.loop.willDelayed.Len())
|
require.Equal(t, 0, s.loop.willDelayed.Len())
|
||||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
|
require.True(t, cl.Closed())
|
||||||
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
|
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2806,7 +2806,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
|||||||
cl0, _, _ := newTestClient()
|
cl0, _, _ := newTestClient()
|
||||||
cl0.ID = "c0"
|
cl0.ID = "c0"
|
||||||
cl0.State.disconnected = n - 10
|
cl0.State.disconnected = n - 10
|
||||||
cl0.State.done = 1
|
cl0.State.open = nil
|
||||||
cl0.Properties.ProtocolVersion = 5
|
cl0.Properties.ProtocolVersion = 5
|
||||||
cl0.Properties.Props.SessionExpiryInterval = 12
|
cl0.Properties.Props.SessionExpiryInterval = 12
|
||||||
cl0.Properties.Props.SessionExpiryIntervalFlag = true
|
cl0.Properties.Props.SessionExpiryIntervalFlag = true
|
||||||
@@ -2816,7 +2816,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
|||||||
cl1, _, _ := newTestClient()
|
cl1, _, _ := newTestClient()
|
||||||
cl1.ID = "c1"
|
cl1.ID = "c1"
|
||||||
cl1.State.disconnected = n - 10
|
cl1.State.disconnected = n - 10
|
||||||
cl1.State.done = 1
|
cl1.State.open = nil
|
||||||
cl1.Properties.ProtocolVersion = 5
|
cl1.Properties.ProtocolVersion = 5
|
||||||
cl1.Properties.Props.SessionExpiryInterval = 8
|
cl1.Properties.Props.SessionExpiryInterval = 8
|
||||||
cl1.Properties.Props.SessionExpiryIntervalFlag = true
|
cl1.Properties.Props.SessionExpiryIntervalFlag = true
|
||||||
@@ -2826,7 +2826,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
|||||||
cl2, _, _ := newTestClient()
|
cl2, _, _ := newTestClient()
|
||||||
cl2.ID = "c2"
|
cl2.ID = "c2"
|
||||||
cl2.State.disconnected = n - 10
|
cl2.State.disconnected = n - 10
|
||||||
cl2.State.done = 1
|
cl2.State.open = nil
|
||||||
cl2.Properties.ProtocolVersion = 5
|
cl2.Properties.ProtocolVersion = 5
|
||||||
cl2.Properties.Props.SessionExpiryInterval = 0
|
cl2.Properties.Props.SessionExpiryInterval = 0
|
||||||
cl2.Properties.Props.SessionExpiryIntervalFlag = true
|
cl2.Properties.Props.SessionExpiryIntervalFlag = true
|
||||||
|
Reference in New Issue
Block a user