diff --git a/examples/paho.testing/main.go b/examples/paho.testing/main.go index b95ce08..82ab085 100644 --- a/examples/paho.testing/main.go +++ b/examples/paho.testing/main.go @@ -28,6 +28,7 @@ func main() { server := mqtt.New(nil) server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true + server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true _ = server.AddHook(new(pahoAuthHook), nil) tcp := listeners.NewTCP("t1", ":1883", nil) diff --git a/packets/codes.go b/packets/codes.go index f84f09d..7e314de 100644 --- a/packets/codes.go +++ b/packets/codes.go @@ -28,6 +28,7 @@ var ( 2: CodeGrantedQos2, } + CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"} CodeSuccess = Code{Code: 0x00, Reason: "success"} CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"} CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"} diff --git a/packets/packets.go b/packets/packets.go index 2cef27a..e53fe12 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -135,6 +135,7 @@ type Packet struct { SessionPresent bool // session existed for connack ReasonCode byte // reason code for a packet response (acks, etc) ReservedBit byte // reserved, do not use (except in testing) + Ignore bool // if true, do not perform any message forwarding operations } // Mods specifies certain values required for certain mqtt v5 compliance within packet encoding/decoding. diff --git a/packets/tpackets.go b/packets/tpackets.go index b16eab0..b86e097 100644 --- a/packets/tpackets.go +++ b/packets/tpackets.go @@ -103,6 +103,7 @@ const ( TPublishBasicMqtt5 TPublishMqtt5 TPublishQos1 + TPublishQos1Mqtt5 TPublishQos1NoPayload TPublishQos1Dup TPublishQos2 @@ -132,6 +133,7 @@ const ( TPubackMqtt5 TPubackMalPacketID TPubackMalProperties + TPubackUnexpectedError TPubrec TPubrecMqtt5 TPubrecMqtt5IDInUse @@ -1704,6 +1706,43 @@ var TPacketData = map[byte]TPacketCases{ PacketID: 7, }, }, + { + Case: TPublishQos1Mqtt5, + Desc: "mqtt v5", + Primary: true, + RawBytes: []byte{ + Publish<<4 | 1<<1, 37, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + // Properties + 16, // length + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 37, + Qos: 1, + }, + PacketID: 7, + TopicName: "a/b/c", + Properties: Properties{ + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, + }, + Payload: []byte("hello mochi"), + }, + }, + { Case: TPublishQos1Dup, Desc: "qos:1, dup:true, packet id", @@ -2235,6 +2274,32 @@ var TPacketData = map[byte]TPacketCases{ }, }, }, + { + Case: TPubackUnexpectedError, + Desc: "unexpected error", + Group: "decode", + RawBytes: []byte{ + Puback << 4, 29, // Fixed header + 0, 7, // Packet ID - LSB+MSB + ErrPayloadFormatInvalid.Code, // Reason Code + 25, // Properties Length + 31, 0, 22, 'p', 'a', 'y', 'l', 'o', 'a', 'd', + ' ', 'f', 'o', 'r', 'm', 'a', 't', + ' ', 'i', 'n', 'v', 'a', 'l', 'i', 'd', // Reason String (31) + }, + Packet: &Packet{ + ProtocolVersion: 5, + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 28, + }, + PacketID: 7, + ReasonCode: ErrPayloadFormatInvalid.Code, + Properties: Properties{ + ReasonString: ErrPayloadFormatInvalid.Reason, + }, + }, + }, // Fail states { @@ -2316,14 +2381,17 @@ var TPacketData = map[byte]TPacketCases{ Desc: "packet id in use mqtt5", Primary: true, RawBytes: []byte{ - Pubrec << 4, 31, // Fixed header + Pubrec << 4, 47, // Fixed header 0, 7, // Packet ID - LSB+MSB ErrPacketIdentifierInUse.Code, // Reason Code - 27, // Properties Length + 43, // Properties Length 31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't', ' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r', ' ', 'i', 'n', ' ', 'u', 's', 'e', // Reason String (31) + 38, // User Properties (38) + 0, 5, 'h', 'e', 'l', 'l', 'o', + 0, 6, 228, 184, 150, 231, 149, 140, }, Packet: &Packet{ ProtocolVersion: 5, @@ -2335,6 +2403,12 @@ var TPacketData = map[byte]TPacketCases{ ReasonCode: ErrPacketIdentifierInUse.Code, Properties: Properties{ ReasonString: ErrPacketIdentifierInUse.Reason, + User: []UserProperty{ + { + Key: "hello", + Val: "世界", + }, + }, }, }, }, diff --git a/server.go b/server.go index e4f4556..df2dcfc 100644 --- a/server.go +++ b/server.go @@ -71,10 +71,11 @@ type Capabilities struct { // Compatibilities provides flags for using compatibility modes. type Compatibilities struct { - ObscureNotAuthorized bool // return unspecified errors instead of not authorized - PassiveClientDisconnect bool // don't disconnect the client forcefully after sending disconnect packet (paho) - AlwaysReturnResponseInfo bool // always return response info (useful for testing) - RestoreSysInfoOnRestart bool // restore system info from store as if server never stopped + ObscureNotAuthorized bool // return unspecified errors instead of not authorized + PassiveClientDisconnect bool // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation) + AlwaysReturnResponseInfo bool // always return response info (useful for testing) + RestoreSysInfoOnRestart bool // restore system info from store as if server never stopped + NoInheritedPropertiesOnAck bool // don't allow inherited user properties on ack (paho - spec violation) } // Options contains configurable options for the server. @@ -715,10 +716,19 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability } - if pkx, err := s.hooks.OnPublish(cl, pk); err == nil { + pkx, err := s.hooks.OnPublish(cl, pk) + if err == nil { pk = pkx } else if errors.Is(err, packets.ErrRejectPacket) { return nil + } else if errors.Is(err, packets.CodeSuccessIgnore) { + pk.Ignore = true + } else if cl.Properties.ProtocolVersion == 5 && pk.FixedHeader.Qos > 0 && errors.As(err, new(packets.Code)) { + err = cl.WritePacket(s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, err.(packets.Code))) + if err != nil { + return err + } + return nil } if pk.FixedHeader.Retain { // [MQTT-3.3.1-5] ![MQTT-3.3.1-8] @@ -742,7 +752,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { s.hooks.OnQosPublish(cl, ack, ack.Created, 0) } - err := cl.WritePacket(ack) + err = cl.WritePacket(ack) if err != nil { return err } @@ -764,7 +774,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { // retainMessage adds a message to a topic, and if a persistent store is provided, // adds the message to the store to be reloaded if necessary. func (s *Server) retainMessage(cl *Client, pk packets.Packet) { - if s.Options.Capabilities.RetainAvailable == 0 { + if s.Options.Capabilities.RetainAvailable == 0 || pk.Ignore { return } @@ -776,6 +786,10 @@ func (s *Server) retainMessage(cl *Client, pk packets.Packet) { // publishToSubscribers publishes a publish packet to all subscribers with matching topic filters. func (s *Server) publishToSubscribers(pk packets.Packet) { + if pk.Ignore { + return + } + if pk.Created == 0 { pk.Created = time.Now().Unix() } @@ -905,7 +919,9 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e // buildAck builds a standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets. func (s *Server) buildAck(packetID uint16, pkt, qos byte, properties packets.Properties, reason packets.Code) packets.Packet { - properties = packets.Properties{} // PRL + if s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck { + properties = packets.Properties{} + } if reason.Code >= packets.ErrUnspecifiedError.Code { properties.ReasonString = reason.Reason } diff --git a/server_test.go b/server_test.go index c16bac9..7a97904 100644 --- a/server_test.go +++ b/server_test.go @@ -1198,6 +1198,50 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) } +func TestServerBuildAck(t *testing.T) { + s := newServer() + properties := packets.Properties{ + User: []packets.UserProperty{ + {Key: "hello", Val: "世界"}, + }, + } + ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1) + require.Equal(t, packets.Puback, ack.FixedHeader.Type) + require.Equal(t, uint8(1), ack.FixedHeader.Qos) + require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode) + require.Equal(t, properties, ack.Properties) +} + +func TestServerBuildAckError(t *testing.T) { + s := newServer() + properties := packets.Properties{ + User: []packets.UserProperty{ + {Key: "hello", Val: "世界"}, + }, + } + ack := s.buildAck(7, packets.Puback, 1, properties, packets.ErrMalformedPacket) + require.Equal(t, packets.Puback, ack.FixedHeader.Type) + require.Equal(t, uint8(1), ack.FixedHeader.Qos) + require.Equal(t, packets.ErrMalformedPacket.Code, ack.ReasonCode) + properties.ReasonString = packets.ErrMalformedPacket.Reason + require.Equal(t, properties, ack.Properties) +} + +func TestServerBuildAckPahoCompatibility(t *testing.T) { + s := newServer() + s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true + properties := packets.Properties{ + User: []packets.UserProperty{ + {Key: "hello", Val: "世界"}, + }, + } + ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1) + require.Equal(t, packets.Puback, ack.FixedHeader.Type) + require.Equal(t, uint8(1), ack.FixedHeader.Qos) + require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode) + require.Equal(t, packets.Properties{}, ack.Properties) +} + func TestServerProcessPacketAndNextImmediate(t *testing.T) { s := newServer() cl, r, w := newTestClient() @@ -1222,7 +1266,7 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) { require.Equal(t, int32(4), cl.State.Inflight.sendQuota) } -func TestServerProcessPacketPublishAckFailure(t *testing.T) { +func TestServerProcessPublishAckFailure(t *testing.T) { s := newServer() s.Serve() defer s.Close() @@ -1236,6 +1280,92 @@ func TestServerProcessPacketPublishAckFailure(t *testing.T) { require.ErrorIs(t, err, io.ErrClosedPipe) } +func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.ErrUnspecifiedError + err := s.AddHook(hook, nil) + require.NoError(t,err) + + cl, _, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + w.Close() + + err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.Error(t, err) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.ErrPayloadFormatInvalid + err := s.AddHook(hook, nil) + require.NoError(t,err) + s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + cl.Properties.ProtocolVersion = 5 + s.Clients.Add(cl) + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.NoError(t, err) + w.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPubackUnexpectedError).RawBytes, buf) +} + +func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) { + s := newServer() + hook := new(modifiedHookBase) + hook.fail = true + hook.err = packets.CodeSuccessIgnore + err := s.AddHook(hook, nil) + require.NoError(t,err) + s.Serve() + defer s.Close() + + cl, r, w := newTestClient() + s.Clients.Add(cl) + + receiver, r2, w2 := newTestClient() + receiver.ID = "receiver" + s.Clients.Add(receiver) + s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"}) + + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r2) + require.NoError(t, err) + receiverBuf <- buf + }() + + + go func() { + err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + require.NoError(t, err) + w.Close() + w2.Close() + }() + + buf, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf) + require.Equal(t, []byte{}, <-receiverBuf) + require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) +} + func TestServerProcessPacketPublishMaximumReceive(t *testing.T) { s := newServer() s.Serve() @@ -1393,6 +1523,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf) } + func TestPublishToSubscribersSelfNoLocal(t *testing.T) { s := newServer() cl, r, w := newTestClient() @@ -1537,6 +1668,32 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) { require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishSubscriberIdentifier).RawBytes, <-receiverBuf) } +func TestPublishToSubscribersPkIgnore(t *testing.T) { + s := newServer() + cl, r, w := newTestClient() + s.Clients.Add(cl) + subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "#", Identifier: 1}) + require.True(t, subbed) + + go func() { + pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet + pk.Ignore = true + s.publishToSubscribers(pk) + time.Sleep(time.Millisecond) + w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, []byte{}, <-receiverBuf) +} + + func TestPublishToClientServerDowngradeQos(t *testing.T) { s := newServer() s.Options.Capabilities.MaximumQos = 1 @@ -1846,6 +2003,27 @@ func TestNoRetainMessageIfUnavailable(t *testing.T) { require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained)) } + +func TestNoRetainMessageIfPkIgnore(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + s.Clients.Add(cl) + + pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet + pk.Ignore = true + s.retainMessage(new(Client), pk) + require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained)) +} + +func TestNoRetainMessage(t *testing.T) { + s := newServer() + cl, _, _ := newTestClient() + s.Clients.Add(cl) + + s.retainMessage(new(Client), *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) + require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Retained)) +} + func TestServerProcessPacketPuback(t *testing.T) { tt := ProtocolTest{ {