From fe5d9ffa6107f53ff68c5fc9b134a25ac52f159c Mon Sep 17 00:00:00 2001 From: mochi-co Date: Mon, 12 Dec 2022 11:37:19 +0000 Subject: [PATCH] Simplify Client construction, add NewClient method to Server, add Publish convenience method --- clients.go | 53 ++------ clients_test.go | 100 +++++++-------- examples/hooks/main.go | 23 +++- inflight_test.go | 12 +- server.go | 73 ++++++++--- server_test.go | 278 +++++++++++++++++++++++++---------------- 6 files changed, 310 insertions(+), 229 deletions(-) diff --git a/clients.go b/clients.go index ff0ed9f..a34f018 100644 --- a/clients.go +++ b/clients.go @@ -146,14 +146,10 @@ type ClientState struct { keepalive uint16 // the number of seconds the connection can wait } -// NewClient returns a new instance of Client. -func NewClient(c net.Conn, o *ops) *Client { +// newClient returns a new instance of Client. This is almost exclusively used by Server +// for creating new clients, but it lives here because it's not dependent. +func newClient(c net.Conn, o *ops) *Client { cl := &Client{ - Net: ClientConnection{ - conn: c, - bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)), - Remote: c.RemoteAddr().String(), - }, State: ClientState{ Inflight: NewInflights(), Subscriptions: NewSubscriptions(), @@ -166,46 +162,19 @@ func NewClient(c net.Conn, o *ops) *Client { ops: o, } + if c != nil { + cl.Net = ClientConnection{ + conn: c, + bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)), + Remote: c.RemoteAddr().String(), + } + } + cl.refreshDeadline(cl.State.keepalive) return cl } -// NewInlineClient returns a client used when publishing from the embedding system. -func NewInlineClient(id, remote string) *Client { - return &Client{ - ID: id, - Net: ClientConnection{ - Remote: remote, - Inline: true, - }, - State: ClientState{ - Inflight: NewInflights(), - Subscriptions: NewSubscriptions(), - TopicAliases: NewTopicAliases(0), - }, - Properties: ClientProperties{ - ProtocolVersion: defaultClientProtocolVersion, // default protocol version - }, - } -} - -// newClientStub returns an instance of Client with minimal initializations, such as -// restoring client data from a db. In particular, the client is marked as offline (done). -func newClientStub() *Client { - return &Client{ - State: ClientState{ - Inflight: NewInflights(), - Subscriptions: NewSubscriptions(), - TopicAliases: NewTopicAliases(0), - done: 1, - }, - Properties: ClientProperties{ - ProtocolVersion: defaultClientProtocolVersion, // default protocol version - }, - } -} - // ParseConnect parses the connect parameters and properties for a client. func (cl *Client) ParseConnect(lid string, pk packets.Packet) { cl.Net.Listener = lid diff --git a/clients_test.go b/clients_test.go index f0cbfdd..c96e246 100644 --- a/clients_test.go +++ b/clients_test.go @@ -22,10 +22,10 @@ const pkInfo = "packet type %v, %s" var errClientStop = errors.New("test stop") -func newClient() (cl *Client, r net.Conn, w net.Conn) { +func newTestClient() (cl *Client, r net.Conn, w net.Conn) { r, w = net.Pipe() - cl = NewClient(w, &ops{ + cl = newClient(w, &ops{ info: new(system.Info), hooks: new(Hooks), log: &logger, @@ -119,34 +119,21 @@ func TestClientsGetByListener(t *testing.T) { } func TestNewClient(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() require.NotNil(t, cl) require.NotNil(t, cl.State.Inflight.internal) require.NotNil(t, cl.State.Subscriptions) - require.Nil(t, cl.StopCause()) -} - -func TestNewClientStub(t *testing.T) { - cl := newClientStub() - require.NotNil(t, cl) - require.NotNil(t, cl.State.Inflight.internal) - require.NotNil(t, cl.State.Subscriptions) - require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done)) -} - -func TestNewInlineClient(t *testing.T) { - cl := NewInlineClient("inline", "local") - require.NotNil(t, cl) - require.NotNil(t, cl.State.Inflight.internal) - require.NotNil(t, cl.State.Subscriptions) - require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done)) - require.Equal(t, "inline", cl.ID) - require.Equal(t, "local", cl.Net.Remote) + require.NotNil(t, cl.State.TopicAliases) + require.Equal(t, defaultKeepalive, cl.State.keepalive) + require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) + require.NotNil(t, cl.Net.conn) + require.NotNil(t, cl.Net.bconn) + require.False(t, cl.Net.Inline) } func TestClientParseConnect(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() pk := packets.Packet{ ProtocolVersion: 4, @@ -183,7 +170,7 @@ func TestClientParseConnect(t *testing.T) { } func TestClientParseConnectOverrideWillDelay(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() pk := packets.Packet{ ProtocolVersion: 4, @@ -208,13 +195,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) { } func TestClientParseConnectNoID(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.ParseConnect("tcp1", packets.Packet{}) require.NotEmpty(t, cl.ID) } func TestClientNextPacketID(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() i, err := cl.NextPacketID() require.NoError(t, err) @@ -226,7 +213,7 @@ func TestClientNextPacketID(t *testing.T) { } func TestClientNextPacketIDInUse(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() // skip over 2 cl.State.Inflight.Set(packets.Packet{PacketID: 2}) @@ -249,7 +236,7 @@ func TestClientNextPacketIDInUse(t *testing.T) { } func TestClientNextPacketIDExhausted(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() for i := 0; i <= 65535; i++ { cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) } @@ -261,7 +248,7 @@ func TestClientNextPacketIDExhausted(t *testing.T) { } func TestClientNextPacketIDOverflow(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.packetID = uint32(65534) @@ -275,7 +262,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) { } func TestClientClearInflights(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() n := time.Now().Unix() cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) @@ -291,7 +278,7 @@ func TestClientClearInflights(t *testing.T) { func TestClientResendInflightMessages(t *testing.T) { pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback) - cl, r, w := newClient() + cl, r, w := newTestClient() cl.State.Inflight.Set(*pk1.Packet) require.Equal(t, 1, cl.State.Inflight.Len()) @@ -311,7 +298,7 @@ func TestClientResendInflightMessages(t *testing.T) { func TestClientResendInflightMessagesWriteFailure(t *testing.T) { pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) - cl, r, _ := newClient() + cl, r, _ := newTestClient() r.Close() cl.State.Inflight.Set(*pk1.Packet) @@ -323,19 +310,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) { } func TestClientResendInflightMessagesNoMessages(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := cl.ResendInflightMessages(true) require.NoError(t, err) } func TestClientRefreshDeadline(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.refreshDeadline(10) require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline? } func TestClientReadFixedHeader(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { @@ -350,7 +337,7 @@ func TestClientReadFixedHeader(t *testing.T) { } func TestClientReadFixedHeaderDecodeError(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { @@ -364,7 +351,7 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) { } func TestClientReadFixedHeaderReadEOF(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { @@ -378,7 +365,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) { } func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { @@ -392,7 +379,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { } func TestClientReadOK(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { r.Write([]byte{ @@ -446,7 +433,7 @@ func TestClientReadOK(t *testing.T) { } func TestClientReadDone(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() defer cl.Stop(errClientStop) cl.State.done = 1 @@ -461,15 +448,16 @@ func TestClientReadDone(t *testing.T) { } func TestClientStop(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Stop(nil) require.Equal(t, nil, cl.State.stopCause.Load()) require.Equal(t, time.Now().Unix(), cl.State.disconnected) require.Equal(t, uint32(1), cl.State.done) + require.Equal(t, nil, cl.StopCause()) } func TestClientReadFixedHeaderError(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { r.Write([]byte{ @@ -486,7 +474,7 @@ func TestClientReadFixedHeaderError(t *testing.T) { } func TestClientReadReadHandlerErr(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { r.Write([]byte{ @@ -506,7 +494,7 @@ func TestClientReadReadHandlerErr(t *testing.T) { } func TestClientReadReadPacketOK(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { r.Write([]byte{ @@ -538,7 +526,7 @@ func TestClientReadReadPacketOK(t *testing.T) { } func TestClientReadPacket(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) for _, tx := range pkTable { @@ -571,9 +559,17 @@ func TestClientReadPacket(t *testing.T) { } } +func TestClientReadPacketInvalidTypeError(t *testing.T) { + cl, _, _ := newTestClient() + cl.Net.conn.Close() + _, err := cl.ReadPacket(&packets.FixedHeader{}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid packet type") +} + func TestClientWritePacket(t *testing.T) { for _, tt := range pkTable { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion @@ -613,7 +609,7 @@ func TestClientWritePacket(t *testing.T) { } func TestWriteClientOversizePacket(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Properties.Props.MaximumPacketSize = 2 pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet err := cl.WritePacket(pk) @@ -622,7 +618,7 @@ func TestWriteClientOversizePacket(t *testing.T) { } func TestClientReadPacketReadingError(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { r.Write([]byte{ @@ -642,7 +638,7 @@ func TestClientReadPacketReadingError(t *testing.T) { } func TestClientReadPacketReadUnknown(t *testing.T) { - cl, r, _ := newClient() + cl, r, _ := newTestClient() defer cl.Stop(errClientStop) go func() { r.Write([]byte{ @@ -661,7 +657,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) { } func TestClientWritePacketWriteNoConn(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Stop(errClientStop) err := cl.WritePacket(*pkTable[1].Packet) @@ -670,7 +666,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) { } func TestClientWritePacketWriteError(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Net.conn.Close() err := cl.WritePacket(*pkTable[1].Packet) @@ -678,7 +674,7 @@ func TestClientWritePacketWriteError(t *testing.T) { } func TestClientWritePacketInvalidPacket(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := cl.WritePacket(packets.Packet{}) require.Error(t, err) } diff --git a/examples/hooks/main.go b/examples/hooks/main.go index 42577f3..6ac5ea5 100644 --- a/examples/hooks/main.go +++ b/examples/hooks/main.go @@ -52,15 +52,30 @@ func main() { // `server.Publish` method. Subscribe to `direct/publish` using your // MQTT client to see the messages. go func() { - cl := mqtt.NewInlineClient("inline", "local") - for range time.Tick(time.Second * 10) { - server.InjectPacket(cl, packets.Packet{ + cl := server.NewClient(nil, "local", "inline", true) + for range time.Tick(time.Second * 1) { + err := server.InjectPacket(cl, packets.Packet{ FixedHeader: packets.FixedHeader{ Type: packets.Publish, }, TopicName: "direct/publish", - Payload: []byte("scheduled message"), + Payload: []byte("injected scheduled message"), }) + if err != nil { + server.Log.Error().Err(err).Msg("server.InjectPacket") + } + server.Log.Info().Msgf("main.go injected packet to direct/publish") + } + }() + + // There is also a shorthand convenience function, Publish, for easily sending + // publish packets if you are not concerned with creating your own packets. + go func() { + for range time.Tick(time.Second * 5) { + err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) + if err != nil { + server.Log.Error().Err(err).Msg("server.Publish") + } server.Log.Info().Msgf("main.go issued direct message to direct/publish") } }() diff --git a/inflight_test.go b/inflight_test.go index 99a23f7..1028796 100644 --- a/inflight_test.go +++ b/inflight_test.go @@ -13,7 +13,7 @@ import ( ) func TestInflightSet(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) require.True(t, r) @@ -25,7 +25,7 @@ func TestInflightSet(t *testing.T) { } func TestInflightGet(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: 2}) msg, ok := cl.State.Inflight.Get(2) @@ -34,7 +34,7 @@ func TestInflightGet(t *testing.T) { } func TestInflightGetAllAndImmediate(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) @@ -56,13 +56,13 @@ func TestInflightGetAllAndImmediate(t *testing.T) { } func TestInflightLen(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: 2}) require.Equal(t, 1, cl.State.Inflight.Len()) } func TestInflightDelete(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: 3}) require.NotNil(t, cl.State.Inflight.internal[3]) @@ -163,7 +163,7 @@ func TestSendQuota(t *testing.T) { } func TestNextImmediate(t *testing.T) { - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) diff --git a/server.go b/server.go index d2c2625..8f71bd1 100644 --- a/server.go +++ b/server.go @@ -199,6 +199,31 @@ func (o *Options) ensureDefaults() { } } +// NewClient returns a new Client instance, populated with all the required values and +// references to be used with the server. If you are using this client to directly publish +// messages from the embedding application, set the inline flag to true to bypass ACL and +// topic validation checks. +func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client { + cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit + capabilities: s.Options.Capabilities, + info: s.Info, + hooks: s.hooks, + log: s.Log, + }) + + cl.ID = id + cl.Net.Listener = listener + + if inline { // inline clients bypass acl and some validity checks. + cl.Net.Inline = true + // By default we don't want to restrict developer publishes, + // but if you do, reset this after creating inline client. + cl.State.Inflight.ResetReceiveQuota(math.MaxInt32) + } + + return cl +} + // AddHook attaches a new Hook to the server. Ideally, this should be called // before the server is started with s.Serve(). func (s *Server) AddHook(hook Hook, config any) error { @@ -281,27 +306,21 @@ func (s *Server) eventLoop() { } // EstablishConnection establishes a new client when a listener accepts a new connection. -func (s *Server) EstablishConnection(lid string, c net.Conn) error { - cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit - capabilities: s.Options.Capabilities, - info: s.Info, - hooks: s.hooks, - log: s.Log, - }) - - return s.attachClient(cl, lid) +func (s *Server) EstablishConnection(listener string, c net.Conn) error { + cl := s.NewClient(c, listener, "", false) + return s.attachClient(cl, listener) } // attachClient validates an incoming client connection and if viable, attaches the client // to the server, performs session housekeeping, and reads incoming packets. -func (s *Server) attachClient(cl *Client, lid string) error { +func (s *Server) attachClient(cl *Client, listener string) error { defer cl.Stop(nil) pk, err := s.readConnectionPacket(cl) if err != nil { return fmt.Errorf("read connection: %w", err) } - cl.ParseConnect(lid, pk) + cl.ParseConnect(listener, pk) code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] if code != packets.CodeSuccess { if err := s.sendConnack(cl, code, false); err != nil { @@ -353,7 +372,7 @@ func (s *Server) attachClient(cl *Client, lid string) error { cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] } - s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected") + s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected") expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) s.hooks.OnDisconnect(cl, err, expire) if expire { @@ -592,6 +611,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error { }) } +// Publish publishes a publish packet into the broker as if it were sent from the speicfied client. +// This is a convenience function which wraps InjectPacket. As such, this method can publish packets +// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the +// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete). +func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error { + cl := s.NewClient(nil, "local", "inline", true) + return s.InjectPacket(cl, packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Qos: qos, + Retain: retain, + }, + TopicName: topic, + Payload: payload, + PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks. + }) +} + // InjectPacket injects a packet into the broker as if it were sent from the specified client. // InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks. func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { @@ -627,7 +664,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { pk.Origin = cl.ID pk.Created = time.Now().Unix() - if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { + if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok && !cl.Net.Inline { if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) return cl.WritePacket(ack) @@ -1087,12 +1124,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error { out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1] } - err := cl.WritePacket(out) + // We already have a code we are using to disconnect the client, so we are not + // interested if the write packet fails due to a closed connection (as we are closing it). + _ = cl.WritePacket(out) if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect { cl.Stop(code) } - return err + return code } // publishSysTopics publishes the current values to the server $SYS topics. @@ -1304,9 +1343,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) { // loadClients restores clients from the datastore. func (s *Server) loadClients(v []storage.Client) { for _, c := range v { - cl := newClientStub() - cl.ID = c.ID - cl.Net.Listener = c.Listener + cl := s.NewClient(nil, c.Listener, c.ID, false) cl.Properties.Username = c.Username cl.Properties.Clean = c.Clean cl.Properties.ProtocolVersion = c.ProtocolVersion diff --git a/server_test.go b/server_test.go index dd11f4f..5a9946d 100644 --- a/server_test.go +++ b/server_test.go @@ -102,7 +102,34 @@ func TestNewNilOpts(t *testing.T) { require.NotNil(t, s.Options) } -func TestAddHook(t *testing.T) { +func TestServerNewClient(t *testing.T) { + s := New(nil) + s.Log = &logger + r, _ := net.Pipe() + + cl := s.NewClient(r, "testing", "test", false) + require.NotNil(t, cl) + require.Equal(t, "test", cl.ID) + require.Equal(t, "testing", cl.Net.Listener) + require.False(t, cl.Net.Inline) + require.NotNil(t, cl.State.Inflight.internal) + require.NotNil(t, cl.State.Subscriptions) + require.NotNil(t, cl.State.TopicAliases) + require.Equal(t, defaultKeepalive, cl.State.keepalive) + require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) + require.NotNil(t, cl.Net.conn) + require.NotNil(t, cl.Net.bconn) + require.NotNil(t, cl.ops) + require.Equal(t, s.Log, cl.ops.log) +} + +func TestServerNewClientInline(t *testing.T) { + s := New(nil) + cl := s.NewClient(nil, "testing", "test", true) + require.True(t, cl.Net.Inline) +} + +func TestServerAddHook(t *testing.T) { s := New(nil) s.Log = &logger require.NotNil(t, s) @@ -113,7 +140,7 @@ func TestAddHook(t *testing.T) { require.Equal(t, int64(1), s.hooks.Len()) } -func TestAddListener(t *testing.T) { +func TestServerAddListener(t *testing.T) { s := newServer() defer s.Close() @@ -128,7 +155,7 @@ func TestAddListener(t *testing.T) { require.Equal(t, ErrListenerIDExists, err) } -func TestAddListenerInitFailure(t *testing.T) { +func TestServerAddListenerInitFailure(t *testing.T) { s := newServer() defer s.Close() @@ -197,7 +224,7 @@ func TestServerReadConnectionPacket(t *testing.T) { s := newServer() defer s.Close() - cl, r, _ := newClient() + cl, r, _ := newTestClient() s.Clients.Add(cl) o := make(chan packets.Packet) @@ -219,7 +246,7 @@ func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) { s := newServer() defer s.Close() - cl, r, _ := newClient() + cl, r, _ := newTestClient() s.Clients.Add(cl) o := make(chan error) @@ -242,7 +269,7 @@ func TestServerReadConnectionPacketBadPacketType(t *testing.T) { s := newServer() defer s.Close() - cl, r, _ := newClient() + cl, r, _ := newTestClient() s.Clients.Add(cl) go func() { @@ -259,7 +286,7 @@ func TestServerReadConnectionPacketBadPacket(t *testing.T) { s := newServer() defer s.Close() - cl, r, _ := newClient() + cl, r, _ := newTestClient() s.Clients.Add(cl) go func() { @@ -377,7 +404,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { s := newServer() defer s.Close() - cl, r0, _ := newClient() + cl, r0, _ := newTestClient() cl.Properties.ProtocolVersion = 5 cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) @@ -438,7 +465,7 @@ func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { defer s.Close() n := time.Now().Unix() - cl, r0, _ := newClient() + cl, r0, _ := newTestClient() cl.Properties.ProtocolVersion = 5 cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier cl.State.Inflight = NewInflights() @@ -474,7 +501,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { s := newServer() defer s.Close() - cl, r0, _ := newClient() + cl, r0, _ := newTestClient() cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier cl.Properties.Clean = true cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) @@ -660,7 +687,7 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) { func TestServerSendConnack(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Options.Capabilities.ServerKeepAlive = 20 s.Options.Capabilities.MaximumQos = 1 @@ -680,7 +707,7 @@ func TestServerSendConnack(t *testing.T) { func TestServerSendConnackFailureReason(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Options.Capabilities.ServerKeepAlive = 20 go func() { @@ -758,7 +785,7 @@ func TestServerValidateConnect(t *testing.T) { func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.Properties.Props.SessionExpiryInterval = uint32(300) s.Options.Capabilities.MaximumSessionExpiryInterval = 120 @@ -778,7 +805,7 @@ func TestInheritClientSession(t *testing.T) { n := time.Now().Unix() - existing, _, _ := newClient() + existing, _, _ := newTestClient() existing.Net.conn = nil existing.ID = "mochi" existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) @@ -788,7 +815,7 @@ func TestInheritClientSession(t *testing.T) { s.Clients.Add(existing) - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Properties.ProtocolVersion = 5 require.Equal(t, 0, cl.State.Inflight.Len()) @@ -801,7 +828,7 @@ func TestInheritClientSession(t *testing.T) { require.Equal(t, 1, cl.State.Subscriptions.Len()) // On clean, clear existing properties - cl, _, _ = newClient() + cl, _, _ = newTestClient() cl.Properties.ProtocolVersion = 5 b = s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: true}}, cl) require.False(t, b) @@ -811,7 +838,7 @@ func TestInheritClientSession(t *testing.T) { func TestServerUnsubscribeClient(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() pk := packets.Subscription{Filter: "a/b/c", Qos: 1} cl.State.Subscriptions.Add("a/b/c", pk) s.Topics.Subscribe(cl.ID, pk) @@ -824,14 +851,14 @@ func TestServerUnsubscribeClient(t *testing.T) { func TestServerProcessPacketFailure(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.processPacket(cl, packets.Packet{}) require.Error(t, err) } func TestServerProcessPacketConnect(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.processPacket(cl, *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet) require.Error(t, err) @@ -839,7 +866,7 @@ func TestServerProcessPacketConnect(t *testing.T) { func TestServerProcessPacketPingreq(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) @@ -854,7 +881,7 @@ func TestServerProcessPacketPingreq(t *testing.T) { func TestServerProcessPacketPingreqError(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Stop(packets.CodeDisconnect) err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet) @@ -864,7 +891,7 @@ func TestServerProcessPacketPingreqError(t *testing.T) { func TestServerProcessPacketPublishInvalid(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet) require.Error(t, err) @@ -876,12 +903,12 @@ func TestInjectPacketPublishAndReceive(t *testing.T) { s.Serve() defer s.Close() - sender, _, w1 := newClient() + sender, _, w1 := newTestClient() sender.Net.Inline = true sender.ID = "sender" s.Clients.Add(sender) - receiver, r2, w2 := newClient() + receiver, r2, w2 := newTestClient() receiver.ID = "receiver" s.Clients.Add(receiver) s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) @@ -906,10 +933,46 @@ func TestInjectPacketPublishAndReceive(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) } +func TestServerDirectPublishAndReceive(t *testing.T) { + s := newServer() + s.Serve() + defer s.Close() + + sender, _, w1 := newTestClient() + sender.Net.Inline = true + sender.ID = "sender" + s.Clients.Add(sender) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos) + require.NoError(t, err) + w1.Close() + time.Sleep(time.Millisecond * 10) + w2.Close() + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) +} + func TestInjectPacketError(t *testing.T) { s := newServer() defer s.Close() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Net.Inline = true pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet pkx.Filters = packets.Subscriptions{} @@ -920,7 +983,7 @@ func TestInjectPacketError(t *testing.T) { func TestInjectPacketPublishInvalidTopic(t *testing.T) { s := newServer() defer s.Close() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Net.Inline = true pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet pkx.TopicName = "$SYS/test" @@ -933,11 +996,11 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { s.Serve() defer s.Close() - sender, _, w1 := newClient() + sender, _, w1 := newTestClient() sender.ID = "sender" s.Clients.Add(sender) - receiver, r2, w2 := newClient() + receiver, r2, w2 := newTestClient() receiver.ID = "receiver" s.Clients.Add(receiver) s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) @@ -966,7 +1029,7 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { func TestServerProcessPacketAndNextImmediate(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() next := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet next.Expiry = -1 @@ -993,7 +1056,7 @@ func TestServerProcessPacketPublishAckFailure(t *testing.T) { s.Serve() defer s.Close() - cl, _, w := newClient() + cl, _, w := newTestClient() s.Clients.Add(cl) w.Close() @@ -1007,14 +1070,15 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { s.Serve() defer s.Close() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.State.Inflight.ResetReceiveQuota(0) s.Clients.Add(cl) go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) - require.NoError(t, err) + require.Error(t, err) + require.ErrorIs(t, err, packets.ErrReceiveMaximum) w.Close() }() @@ -1027,7 +1091,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) { s := newServer() s.Serve() defer s.Close() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet) require.NoError(t, err) // $SYS topics should be ignored? } @@ -1040,7 +1104,7 @@ func TestServerProcessPublishACLCheckDeny(t *testing.T) { }) s.Serve() defer s.Close() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) // ACL check fails silently } @@ -1057,14 +1121,14 @@ func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) { s.Serve() defer s.Close() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) require.NoError(t, err) // packets rejected silently } func TestServerProcessPacketPublishQos0(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) @@ -1079,7 +1143,7 @@ func TestServerProcessPacketPublishQos0(t *testing.T) { func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) atomic.StoreInt64(&s.Info.Inflight, 1) @@ -1097,7 +1161,7 @@ func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) { func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}}) atomic.StoreInt64(&s.Info.Inflight, 1) @@ -1116,7 +1180,7 @@ func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) { func TestServerProcessPacketPublishQos1(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) @@ -1131,7 +1195,7 @@ func TestServerProcessPacketPublishQos1(t *testing.T) { func TestServerProcessPacketPublishQos2(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) @@ -1147,7 +1211,7 @@ func TestServerProcessPacketPublishQos2(t *testing.T) { func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { s := newServer() s.Options.Capabilities.MaximumQos = 1 - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet) @@ -1162,7 +1226,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { func TestPublishToSubscribersSelfNoLocal(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true}) require.True(t, subbed) @@ -1187,11 +1251,11 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) { func TestPublishToSubscribers(t *testing.T) { s := newServer() - cl, r1, w1 := newClient() + cl, r1, w1 := newTestClient() cl.ID = "cl1" - cl2, r2, w2 := newClient() + cl2, r2, w2 := newTestClient() cl2.ID = "cl2" - cl3, r3, w3 := newClient() + cl3, r3, w3 := newTestClient() cl3.ID = "cl3" s.Clients.Add(cl) s.Clients.Add(cl2) @@ -1249,7 +1313,7 @@ func TestPublishToSubscribers(t *testing.T) { func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { s := newServer() s.Options.Capabilities.MaximumMessageExpiryInterval = 86400 - cl, r1, w1 := newClient() + cl, r1, w1 := newTestClient() cl.ID = "cl1" cl.Properties.ProtocolVersion = 5 s.Clients.Add(cl) @@ -1278,7 +1342,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { func TestPublishToSubscribersIdentifiers(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Clients.Add(cl) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/+", Identifier: 2}) @@ -1308,7 +1372,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { s := newServer() s.Options.Capabilities.MaximumQos = 1 - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) _, ok := cl.State.Inflight.Get(1) @@ -1334,7 +1398,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { func TestPublishToClientServerTopicAlias(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.Properties.Props.TopicAliasMaximum = 5 s.Clients.Add(cl) @@ -1363,7 +1427,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) { func TestPublishToClientExhaustedPacketID(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() for i := 0; i <= 65535; i++ { cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) } @@ -1375,7 +1439,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) { func TestPublishToClientNoConn(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Net.conn = nil _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) @@ -1385,12 +1449,12 @@ func TestPublishToClientNoConn(t *testing.T) { func TestProcessPublishWithTopicAlias(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) require.True(t, subbed) - cl2, _, w2 := newClient() + cl2, _, w2 := newTestClient() cl2.Properties.ProtocolVersion = 5 cl2.State.TopicAliases.Inbound.Set(1, "a/b/c") @@ -1412,7 +1476,7 @@ func TestProcessPublishWithTopicAlias(t *testing.T) { func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) cl.State.Inflight.sendQuota = 0 @@ -1431,7 +1495,7 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) for i := 0; i <= 65535; i++ { cl.State.Inflight.Set(packets.Packet{PacketID: 1}) @@ -1452,7 +1516,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { func TestPublishToSubscribersNoConnection(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) require.True(t, subbed) @@ -1467,7 +1531,7 @@ func TestPublishToSubscribersNoConnection(t *testing.T) { func TestPublishRetainedToClient(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) @@ -1489,7 +1553,7 @@ func TestPublishRetainedToClient(t *testing.T) { func TestPublishRetainedToClientIsShared(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) sub := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"} @@ -1508,7 +1572,7 @@ func TestPublishRetainedToClientIsShared(t *testing.T) { func TestPublishRetainedToClientError(t *testing.T) { s := newServer() - cl, _, w := newClient() + cl, _, w := newTestClient() s.Clients.Add(cl) sub := packets.Subscription{Filter: "a/b/c"} @@ -1538,7 +1602,7 @@ func TestServerProcessPacketPuback(t *testing.T) { t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) { pID := uint16(7) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1560,7 +1624,7 @@ func TestServerProcessPacketPuback(t *testing.T) { func TestServerProcessPacketPubackNoPacketID(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1575,7 +1639,7 @@ func TestServerProcessPacketPubackNoPacketID(t *testing.T) { func TestServerProcessPacketPubrec(t *testing.T) { pID := uint16(7) s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1604,7 +1668,7 @@ func TestServerProcessPacketPubrec(t *testing.T) { func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1630,7 +1694,7 @@ func TestServerProcessPacketPubrecNoPacketID(t *testing.T) { func TestServerProcessPacketPubrecInvalidReason(t *testing.T) { pID := uint16(7) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: pID}) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet) require.NoError(t, err) @@ -1642,7 +1706,7 @@ func TestServerProcessPacketPubrecInvalidReason(t *testing.T) { func TestServerProcessPacketPubrecFailure(t *testing.T) { pID := uint16(7) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: pID}) cl.Stop(packets.CodeDisconnect) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet) @@ -1653,7 +1717,7 @@ func TestServerProcessPacketPubrecFailure(t *testing.T) { func TestServerProcessPacketPubrel(t *testing.T) { pID := uint16(7) s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1683,7 +1747,7 @@ func TestServerProcessPacketPubrel(t *testing.T) { func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1709,7 +1773,7 @@ func TestServerProcessPacketPubrelNoPacketID(t *testing.T) { func TestServerProcessPacketPubrelFailure(t *testing.T) { pID := uint16(7) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: pID}) cl.Stop(packets.CodeDisconnect) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet) @@ -1720,7 +1784,7 @@ func TestServerProcessPacketPubrelFailure(t *testing.T) { func TestServerProcessPacketPubrelBadReason(t *testing.T) { pID := uint16(7) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.Inflight.Set(packets.Packet{PacketID: pID}) err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet) require.NoError(t, err) @@ -1745,7 +1809,7 @@ func TestServerProcessPacketPubcomp(t *testing.T) { t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) { pID := uint16(7) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Properties.ProtocolVersion = tx.protocolVersion cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1792,7 +1856,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) { pID := uint16(7) s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1863,7 +1927,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { pID := uint16(6) s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.State.packetID = uint32(6) cl.State.Inflight.sendQuota = 3 cl.State.Inflight.receiveQuota = 3 @@ -1907,7 +1971,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { func TestServerProcessPacketSubscribe(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet) @@ -1922,7 +1986,7 @@ func TestServerProcessPacketSubscribe(t *testing.T) { func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) @@ -1941,7 +2005,7 @@ func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) { func TestServerProcessPacketSubscribeInvalid(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Properties.ProtocolVersion = 5 err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet) @@ -1951,7 +2015,7 @@ func TestServerProcessPacketSubscribeInvalid(t *testing.T) { func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { @@ -1967,7 +2031,7 @@ func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) { func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { @@ -1983,7 +2047,7 @@ func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) { func TestServerProcessSubscribeWithRetain(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) require.Equal(t, int64(1), retained) @@ -2007,7 +2071,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) { func TestServerProcessSubscribeDowngradeQos(t *testing.T) { s := newServer() s.Options.Capabilities.MaximumQos = 1 - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet) @@ -2024,7 +2088,7 @@ func TestServerProcessSubscribeDowngradeQos(t *testing.T) { func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}) s.Clients.Add(cl) @@ -2046,7 +2110,7 @@ func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) { func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) @@ -2067,7 +2131,7 @@ func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) { func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() s.Clients.Add(cl) retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) @@ -2091,7 +2155,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) { func TestServerProcessSubscribeNoConnection(t *testing.T) { s := newServer() - cl, r, _ := newClient() + cl, r, _ := newTestClient() r.Close() err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet) require.Error(t, err) @@ -2105,7 +2169,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { FanPoolQueueSize: 10, }) s.Serve() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { @@ -2127,7 +2191,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { }) s.Serve() s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 go func() { @@ -2143,7 +2207,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 3 cl.State.packetID = 1 // just to match the same packet id (7) in the fixtures @@ -2160,7 +2224,7 @@ func TestServerProcessSubscribeErrorDowngrade(t *testing.T) { func TestServerProcessPacketUnsubscribe(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b", Qos: 0}) go func() { @@ -2177,7 +2241,7 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) { func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.ProtocolVersion = 5 cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}}) go func() { @@ -2194,7 +2258,7 @@ func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) { func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) @@ -2202,7 +2266,7 @@ func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) { func TestServerReceivePacketError(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() err := s.receivePacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID) @@ -2210,7 +2274,7 @@ func TestServerReceivePacketError(t *testing.T) { func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() cl.Properties.Props.SessionExpiryInterval = 0 cl.Properties.ProtocolVersion = 5 cl.Properties.Props.RequestProblemInfo = 0 @@ -2229,7 +2293,7 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) { func TestServerProcessPacketDisconnect(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Properties.Props.SessionExpiryInterval = 30 cl.Properties.ProtocolVersion = 5 @@ -2246,7 +2310,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) { func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.Properties.Props.SessionExpiryInterval = 0 cl.Properties.ProtocolVersion = 5 cl.Properties.Props.RequestProblemInfo = 0 @@ -2259,7 +2323,7 @@ func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) { func TestServerProcessPacketAuth(t *testing.T) { s := newServer() - cl, r, w := newClient() + cl, r, w := newTestClient() go func() { err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet) @@ -2274,7 +2338,7 @@ func TestServerProcessPacketAuth(t *testing.T) { func TestServerProcessPacketAuthInvalidReason(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet pkx.ReasonCode = 99 err := s.processPacket(cl, pkx) @@ -2284,7 +2348,7 @@ func TestServerProcessPacketAuthInvalidReason(t *testing.T) { func TestServerProcessPacketAuthFailure(t *testing.T) { s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() hook := new(modifiedHookBase) hook.fail = true @@ -2301,7 +2365,7 @@ func TestServerSendLWT(t *testing.T) { s.Serve() defer s.Close() - sender, _, w1 := newClient() + sender, _, w1 := newTestClient() sender.ID = "sender" sender.Properties.Will = Will{ Flag: 1, @@ -2310,7 +2374,7 @@ func TestServerSendLWT(t *testing.T) { } s.Clients.Add(sender) - receiver, r2, w2 := newClient() + receiver, r2, w2 := newTestClient() receiver.ID = "receiver" s.Clients.Add(receiver) s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) @@ -2337,7 +2401,7 @@ func TestServerSendLWT(t *testing.T) { func TestServerSendLWTDelayed(t *testing.T) { s := newServer() - cl1, _, _ := newClient() + cl1, _, _ := newTestClient() cl1.ID = "cl1" cl1.Properties.Will = Will{ Flag: 1, @@ -2348,7 +2412,7 @@ func TestServerSendLWTDelayed(t *testing.T) { } s.Clients.Add(cl1) - cl2, r, w := newClient() + cl2, r, w := newTestClient() cl2.ID = "cl2" s.Clients.Add(cl2) require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: "a/b/c"})) @@ -2426,7 +2490,7 @@ func TestServerLoadSubscriptions(t *testing.T) { } s := newServer() - cl, _, _ := newClient() + cl, _, _ := newTestClient() s.Clients.Add(cl) require.Equal(t, 0, cl.State.Subscriptions.Len()) s.loadSubscriptions(v) @@ -2486,7 +2550,7 @@ func TestServerClose(t *testing.T) { hook := new(modifiedHookBase) s.AddHook(hook, nil) - cl, r, _ := newClient() + cl, r, _ := newTestClient() cl.Net.Listener = "t1" cl.Properties.ProtocolVersion = 5 s.Clients.Add(cl) @@ -2524,7 +2588,7 @@ func TestServerClearExpiredInflights(t *testing.T) { s.Options.Capabilities.MaximumMessageExpiryInterval = 4 n := time.Now().Unix() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.ops.info = s.Info cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) @@ -2564,12 +2628,12 @@ func TestServerClearExpiredClients(t *testing.T) { n := time.Now().Unix() - cl, _, _ := newClient() + cl, _, _ := newTestClient() cl.ID = "cl" s.Clients.Add(cl) // No Expiry - cl0, _, _ := newClient() + cl0, _, _ := newTestClient() cl0.ID = "c0" cl0.State.disconnected = n - 10 cl0.State.done = 1 @@ -2579,7 +2643,7 @@ func TestServerClearExpiredClients(t *testing.T) { s.Clients.Add(cl0) // Normal Expiry - cl1, _, _ := newClient() + cl1, _, _ := newTestClient() cl1.ID = "c1" cl1.State.disconnected = n - 10 cl1.State.done = 1 @@ -2589,7 +2653,7 @@ func TestServerClearExpiredClients(t *testing.T) { s.Clients.Add(cl1) // No Expiry, indefinite session - cl2, _, _ := newClient() + cl2, _, _ := newTestClient() cl2.ID = "c2" cl2.State.disconnected = n - 10 cl2.State.done = 1