Compare commits

...

5 Commits

Author SHA1 Message Date
mochi-co
ef5dcf68d0 Use context to signal client open state 2023-05-06 11:49:02 +01:00
JB
6704cf7227 Add packet ID exhausted hook (#217) 2023-05-06 10:37:27 +01:00
thedevop
9233e6fd39 Expire session if SessionExpiryInterval is 0 (#216)
If SessionExpiryInterval was not set in CONNECT, SessionExpiryIntervalFlag is also not set. According to spec:
  If the Session Expiry Interval is absent the value 0 is used. If it is set to 0, or is absent, the Session ends when the Network Connection is closed.
2023-05-06 10:12:33 +01:00
ħþ
1ca65d9631 Update codes.go (#215)
fix typo
2023-05-06 10:02:25 +01:00
ħþ
33229da885 Update codes.go (#214)
Fix typo
2023-05-06 09:59:50 +01:00
5 changed files with 27 additions and 18 deletions

View File

@@ -7,6 +7,7 @@ package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -144,7 +145,7 @@ type ClientState struct {
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
open context.Context // indicate that the client is open for packet exchange
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
@@ -158,6 +159,7 @@ func newClient(c net.Conn, o *ops) *Client {
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
open: context.Background(),
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
@@ -330,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
@@ -371,7 +373,12 @@ func (cl *Client) Stop(err error) {
close(cl.State.outbound)
}
atomic.StoreUint32(&cl.State.done, 1)
if cl.State.open != nil {
var cancel context.CancelFunc
cl.State.open, cancel = context.WithCancel(cl.State.open)
cancel()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -386,7 +393,7 @@ func (cl *Client) StopCause() error {
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
@@ -548,7 +555,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
cl.Lock()
defer cl.Unlock()
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return ErrConnectionClosed
}

View File

@@ -5,6 +5,7 @@
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -114,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -466,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.open = nil
o := make(chan error)
go func() {
@@ -483,7 +484,7 @@ func TestClientStop(t *testing.T) {
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}

View File

@@ -21,7 +21,7 @@ func (c Code) Error() string {
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
// QosCodes indicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
@@ -120,7 +120,7 @@ var (
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}

View File

@@ -372,7 +372,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
}
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
@@ -826,6 +826,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if out.FixedHeader.Qos > 0 {
i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
if err != nil {
s.hooks.OnPacketIDExhausted(cl, pk)
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted")
return out, packets.ErrQuotaExceeded
}
@@ -846,7 +847,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
}
}
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Net.Conn == nil || cl.Closed() {
return out, packets.CodeDisconnect
}

View File

@@ -2476,7 +2476,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, s.loop.willDelayed.Len())
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
require.True(t, cl.Closed())
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
}
@@ -2806,7 +2806,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl0, _, _ := newTestClient()
cl0.ID = "c0"
cl0.State.disconnected = n - 10
cl0.State.done = 1
cl0.State.open = nil
cl0.Properties.ProtocolVersion = 5
cl0.Properties.Props.SessionExpiryInterval = 12
cl0.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2816,7 +2816,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl1, _, _ := newTestClient()
cl1.ID = "c1"
cl1.State.disconnected = n - 10
cl1.State.done = 1
cl1.State.open = nil
cl1.Properties.ProtocolVersion = 5
cl1.Properties.Props.SessionExpiryInterval = 8
cl1.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2826,7 +2826,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl2, _, _ := newTestClient()
cl2.ID = "c2"
cl2.State.disconnected = n - 10
cl2.State.done = 1
cl2.State.open = nil
cl2.Properties.ProtocolVersion = 5
cl2.Properties.Props.SessionExpiryInterval = 0
cl2.Properties.Props.SessionExpiryIntervalFlag = true