Compare commits

...

5 Commits

Author SHA1 Message Date
mochi-co
d3785c2717 update server version 2023-05-08 11:43:46 +01:00
thedevop
52a347169a Use context to exit WriteLoop (#222)
* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Fix misspelling
2023-05-08 11:30:44 +01:00
mochi-co
797d75cb34 update server version 2023-05-06 14:32:42 +01:00
JB
5225a357e5 refactor server keepalive for hook access (#220) 2023-05-06 14:11:54 +01:00
JB
a734a0dc73 Use context to signal client open state (#218) 2023-05-06 11:55:40 +01:00
8 changed files with 126 additions and 88 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
cmd/mqtt
.DS_Store
*.db
.idea

View File

@@ -7,6 +7,7 @@ package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -135,29 +136,34 @@ type Will struct {
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
cancelOpen context.CancelFunc // cancel function for open context
outboundQty int32 // number of messages currently in the outbound queue
Keepalive uint16 // the number of seconds the connection can wait
ServerKeepalive bool // keepalive was set by the server
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
ctx, cancel := context.WithCancel(context.Background())
cl := &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
open: ctx,
cancelOpen: cancel,
Keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
@@ -177,18 +183,21 @@ func newClient(c net.Conn, o *ops) *Client {
}
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for pk := range cl.State.outbound {
if err := cl.WritePacket(*pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
for {
select {
case pk := <-cl.State.outbound:
if err := cl.WritePacket(*pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
case <-cl.State.open.Done():
return
}
atomic.AddInt32(&cl.State.outboundQty, -1)
}
}
@@ -201,9 +210,9 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
@@ -212,11 +221,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
@@ -234,8 +238,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
cl.refreshDeadline(cl.State.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
@@ -330,11 +332,11 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
cl.refreshDeadline(cl.State.keepalive)
cl.refreshDeadline(cl.State.Keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
@@ -356,8 +358,6 @@ func (cl *Client) Read(packetHandler ReadFn) error {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
cl.State.endOnce.Do(func() {
cl.Lock()
defer cl.Unlock()
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
@@ -367,11 +367,10 @@ func (cl *Client) Stop(err error) {
cl.State.stopCause.Store(err)
}
if cl.State.outbound != nil {
close(cl.State.outbound)
if cl.State.cancelOpen != nil {
cl.State.cancelOpen()
}
atomic.StoreUint32(&cl.State.done, 1)
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -386,7 +385,7 @@ func (cl *Client) StopCause() error {
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
@@ -548,7 +547,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
cl.Lock()
defer cl.Unlock()
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return ErrConnectionClosed
}

View File

@@ -5,6 +5,7 @@
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -114,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -132,7 +133,7 @@ func TestNewClient(t *testing.T) {
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
@@ -164,7 +165,7 @@ func TestClientParseConnect(t *testing.T) {
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
@@ -466,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.cancelOpen()
o := make(chan error)
go func() {
@@ -483,7 +484,7 @@ func TestClientStop(t *testing.T) {
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}

View File

@@ -30,14 +30,14 @@ func main() {
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
err = server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}

View File

@@ -26,7 +26,6 @@ func main() {
}()
server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
@@ -61,6 +60,7 @@ func (h *pahoAuthHook) ID() string {
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnConnect,
mqtt.OnACLCheck,
}, []byte{b})
}
@@ -72,3 +72,11 @@ func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet)
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}
func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
// Handle paho test_server_keep_alive
if pk.Connect.Keepalive == 120 && pk.Connect.Clean {
cl.State.Keepalive = 60
cl.State.ServerKeepalive = true
}
}

View File

@@ -82,6 +82,7 @@ const (
TConnackAcceptedAdjustedExpiryInterval
TConnackMinMqtt5
TConnackMinCleanMqtt5
TConnackServerKeepalive
TConnackInvalidMinMqtt5
TConnackBadProtocolVersion
TConnackProtocolViolationNoSession
@@ -1085,25 +1086,22 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted, no session, adjusted expiry interval mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 11, // fixed header
Connack << 4, 8, // fixed header
0, // Session present
CodeSuccess.Code,
8, // length
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
19, 0, 10, // Server Keep Alive (19)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 11,
Remaining: 8,
},
ReasonCode: CodeSuccess.Code,
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
},
},
},
@@ -1190,28 +1188,25 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 16, // fixed header
Connack << 4, 13, // fixed header
1, // existing session
CodeSuccess.Code,
13, // Properties length
10, // Properties length
18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18)
19, 0, 20, // Server Keep Alive (19)
36, 1, // Maximum Qos (36)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 16,
Remaining: 13,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
},
},
},
@@ -1220,11 +1215,10 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5b",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
Connack << 4, 3, // fixed header
0, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
0, // Properties length
},
Packet: &Packet{
ProtocolVersion: 5,
@@ -1234,6 +1228,27 @@ var TPacketData = map[byte]TPacketCases{
},
SessionPresent: false,
ReasonCode: CodeSuccess.Code,
},
},
{
Case: TConnackServerKeepalive,
Desc: "server set keepalive",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
1, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 6,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
@@ -1245,26 +1260,23 @@ var TPacketData = map[byte]TPacketCases{
Desc: "failure min properties mqtt5",
Primary: true,
RawBytes: append([]byte{
Connack << 4, 26, // fixed header
Connack << 4, 23, // fixed header
0, // No existing session
ErrUnspecifiedError.Code,
// Properties
23, // length
19, 0, 20, // Server Keep Alive (19)
20, // length
31, 0, 17, // Reason String (31)
}, []byte(ErrUnspecifiedError.Reason)...),
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 25,
Remaining: 23,
},
SessionPresent: false,
ReasonCode: ErrUnspecifiedError.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
ReasonString: ErrUnspecifiedError.Reason,
ReasonString: ErrUnspecifiedError.Reason,
},
},
},

View File

@@ -26,8 +26,8 @@ import (
)
const (
Version = "2.2.8" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
Version = "2.2.10" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
var (
@@ -43,7 +43,6 @@ var (
WildcardSubAvailable: 1, // wildcard subscriptions are available
SubIDAvailable: 1, // subscription identifiers are available
SharedSubAvailable: 1, // shared subscriptions are available
ServerKeepAlive: 10, // default keepalive for clients
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
@@ -61,7 +60,6 @@ type Capabilities struct {
maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16
TopicAliasMaximum uint16
ServerKeepAlive uint16
SharedSubAvailable byte
MinimumProtocolVersion byte
Compatibilities Compatibilities
@@ -331,6 +329,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
}
s.hooks.OnConnect(cl, pk)
cl.refreshDeadline(cl.State.Keepalive)
if !s.hooks.OnConnectAuthenticate(cl, pk) { // [MQTT-3.1.4-2]
err := s.sendConnack(cl, packets.ErrBadUsernameOrPassword, false)
@@ -498,9 +497,12 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
// sendConnack returns a Connack packet to a client.
func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) error {
properties := packets.Properties{
ServerKeepAlive: s.Options.Capabilities.ServerKeepAlive, // [MQTT-3.1.2-21]
ServerKeepAliveFlag: true,
ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, // 3.2.2.3.3 Receive Maximum
ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, // 3.2.2.3.3 Receive Maximum
}
if cl.State.ServerKeepalive { // You can set this dynamically using the OnConnect hook.
properties.ServerKeepAlive = cl.State.Keepalive // [MQTT-3.1.2-21]
properties.ServerKeepAliveFlag = true
}
if reason.Code >= packets.ErrUnspecifiedError.Code {
@@ -847,7 +849,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
}
}
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Net.Conn == nil || cl.Closed() {
return out, packets.CodeDisconnect
}

View File

@@ -127,7 +127,7 @@ func TestServerNewClient(t *testing.T) {
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
@@ -821,7 +821,6 @@ func TestServerSendConnack(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20
s.Options.Capabilities.MaximumQos = 1
cl.Properties.Props = packets.Properties{
AssignedClientID: "mochi",
@@ -841,7 +840,6 @@ func TestServerSendConnackFailureReason(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20
go func() {
err := s.sendConnack(cl, packets.ErrUnspecifiedError, true)
require.NoError(t, err)
@@ -853,6 +851,23 @@ func TestServerSendConnackFailureReason(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes, buf)
}
func TestServerSendConnackWithServerKeepalive(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
cl.State.Keepalive = 10
cl.State.ServerKeepalive = true
go func() {
err := s.sendConnack(cl, packets.CodeSuccess, true)
require.NoError(t, err)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes, buf)
}
func TestServerValidateConnect(t *testing.T) {
packet := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet
invalidBitPacket := packet
@@ -2476,7 +2491,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, s.loop.willDelayed.Len())
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
require.True(t, cl.Closed())
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
}
@@ -2806,7 +2821,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl0, _, _ := newTestClient()
cl0.ID = "c0"
cl0.State.disconnected = n - 10
cl0.State.done = 1
cl0.State.cancelOpen()
cl0.Properties.ProtocolVersion = 5
cl0.Properties.Props.SessionExpiryInterval = 12
cl0.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2816,7 +2831,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl1, _, _ := newTestClient()
cl1.ID = "c1"
cl1.State.disconnected = n - 10
cl1.State.done = 1
cl1.State.cancelOpen()
cl1.Properties.ProtocolVersion = 5
cl1.Properties.Props.SessionExpiryInterval = 8
cl1.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2826,7 +2841,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl2, _, _ := newTestClient()
cl2.ID = "c2"
cl2.State.disconnected = n - 10
cl2.State.done = 1
cl2.State.cancelOpen()
cl2.Properties.ProtocolVersion = 5
cl2.Properties.Props.SessionExpiryInterval = 0
cl2.Properties.Props.SessionExpiryIntervalFlag = true