mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-27 04:26:23 +08:00
Allow Publish to return custom Ack error responses (#256)
* Allow publish error returns as acks * Add Ignore Packet, tests
This commit is contained in:
@@ -28,6 +28,7 @@ func main() {
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||
server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
|
||||
|
||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
|
@@ -28,6 +28,7 @@ var (
|
||||
2: CodeGrantedQos2,
|
||||
}
|
||||
|
||||
CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"}
|
||||
CodeSuccess = Code{Code: 0x00, Reason: "success"}
|
||||
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
|
||||
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
|
||||
|
@@ -135,6 +135,7 @@ type Packet struct {
|
||||
SessionPresent bool // session existed for connack
|
||||
ReasonCode byte // reason code for a packet response (acks, etc)
|
||||
ReservedBit byte // reserved, do not use (except in testing)
|
||||
Ignore bool // if true, do not perform any message forwarding operations
|
||||
}
|
||||
|
||||
// Mods specifies certain values required for certain mqtt v5 compliance within packet encoding/decoding.
|
||||
|
@@ -103,6 +103,7 @@ const (
|
||||
TPublishBasicMqtt5
|
||||
TPublishMqtt5
|
||||
TPublishQos1
|
||||
TPublishQos1Mqtt5
|
||||
TPublishQos1NoPayload
|
||||
TPublishQos1Dup
|
||||
TPublishQos2
|
||||
@@ -132,6 +133,7 @@ const (
|
||||
TPubackMqtt5
|
||||
TPubackMalPacketID
|
||||
TPubackMalProperties
|
||||
TPubackUnexpectedError
|
||||
TPubrec
|
||||
TPubrecMqtt5
|
||||
TPubrecMqtt5IDInUse
|
||||
@@ -1704,6 +1706,43 @@ var TPacketData = map[byte]TPacketCases{
|
||||
PacketID: 7,
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPublishQos1Mqtt5,
|
||||
Desc: "mqtt v5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Publish<<4 | 1<<1, 37, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
// Properties
|
||||
16, // length
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Remaining: 37,
|
||||
Qos: 1,
|
||||
},
|
||||
PacketID: 7,
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TPublishQos1Dup,
|
||||
Desc: "qos:1, dup:true, packet id",
|
||||
@@ -2235,6 +2274,32 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TPubackUnexpectedError,
|
||||
Desc: "unexpected error",
|
||||
Group: "decode",
|
||||
RawBytes: []byte{
|
||||
Puback << 4, 29, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrPayloadFormatInvalid.Code, // Reason Code
|
||||
25, // Properties Length
|
||||
31, 0, 22, 'p', 'a', 'y', 'l', 'o', 'a', 'd',
|
||||
' ', 'f', 'o', 'r', 'm', 'a', 't',
|
||||
' ', 'i', 'n', 'v', 'a', 'l', 'i', 'd', // Reason String (31)
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Puback,
|
||||
Remaining: 28,
|
||||
},
|
||||
PacketID: 7,
|
||||
ReasonCode: ErrPayloadFormatInvalid.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrPayloadFormatInvalid.Reason,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Fail states
|
||||
{
|
||||
@@ -2316,14 +2381,17 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "packet id in use mqtt5",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Pubrec << 4, 31, // Fixed header
|
||||
Pubrec << 4, 47, // Fixed header
|
||||
0, 7, // Packet ID - LSB+MSB
|
||||
ErrPacketIdentifierInUse.Code, // Reason Code
|
||||
27, // Properties Length
|
||||
43, // Properties Length
|
||||
31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't',
|
||||
' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r',
|
||||
' ', 'i', 'n',
|
||||
' ', 'u', 's', 'e', // Reason String (31)
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
@@ -2335,6 +2403,12 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ReasonCode: ErrPacketIdentifierInUse.Code,
|
||||
Properties: Properties{
|
||||
ReasonString: ErrPacketIdentifierInUse.Reason,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
32
server.go
32
server.go
@@ -71,10 +71,11 @@ type Capabilities struct {
|
||||
|
||||
// Compatibilities provides flags for using compatibility modes.
|
||||
type Compatibilities struct {
|
||||
ObscureNotAuthorized bool // return unspecified errors instead of not authorized
|
||||
PassiveClientDisconnect bool // don't disconnect the client forcefully after sending disconnect packet (paho)
|
||||
AlwaysReturnResponseInfo bool // always return response info (useful for testing)
|
||||
RestoreSysInfoOnRestart bool // restore system info from store as if server never stopped
|
||||
ObscureNotAuthorized bool // return unspecified errors instead of not authorized
|
||||
PassiveClientDisconnect bool // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation)
|
||||
AlwaysReturnResponseInfo bool // always return response info (useful for testing)
|
||||
RestoreSysInfoOnRestart bool // restore system info from store as if server never stopped
|
||||
NoInheritedPropertiesOnAck bool // don't allow inherited user properties on ack (paho - spec violation)
|
||||
}
|
||||
|
||||
// Options contains configurable options for the server.
|
||||
@@ -715,10 +716,19 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
pk.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] Reduce Qos based on server max qos capability
|
||||
}
|
||||
|
||||
if pkx, err := s.hooks.OnPublish(cl, pk); err == nil {
|
||||
pkx, err := s.hooks.OnPublish(cl, pk)
|
||||
if err == nil {
|
||||
pk = pkx
|
||||
} else if errors.Is(err, packets.ErrRejectPacket) {
|
||||
return nil
|
||||
} else if errors.Is(err, packets.CodeSuccessIgnore) {
|
||||
pk.Ignore = true
|
||||
} else if cl.Properties.ProtocolVersion == 5 && pk.FixedHeader.Qos > 0 && errors.As(err, new(packets.Code)) {
|
||||
err = cl.WritePacket(s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, err.(packets.Code)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Retain { // [MQTT-3.3.1-5] ![MQTT-3.3.1-8]
|
||||
@@ -742,7 +752,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
s.hooks.OnQosPublish(cl, ack, ack.Created, 0)
|
||||
}
|
||||
|
||||
err := cl.WritePacket(ack)
|
||||
err = cl.WritePacket(ack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -764,7 +774,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
// retainMessage adds a message to a topic, and if a persistent store is provided,
|
||||
// adds the message to the store to be reloaded if necessary.
|
||||
func (s *Server) retainMessage(cl *Client, pk packets.Packet) {
|
||||
if s.Options.Capabilities.RetainAvailable == 0 {
|
||||
if s.Options.Capabilities.RetainAvailable == 0 || pk.Ignore {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -776,6 +786,10 @@ func (s *Server) retainMessage(cl *Client, pk packets.Packet) {
|
||||
|
||||
// publishToSubscribers publishes a publish packet to all subscribers with matching topic filters.
|
||||
func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
if pk.Ignore {
|
||||
return
|
||||
}
|
||||
|
||||
if pk.Created == 0 {
|
||||
pk.Created = time.Now().Unix()
|
||||
}
|
||||
@@ -905,7 +919,9 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
|
||||
|
||||
// buildAck builds a standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets.
|
||||
func (s *Server) buildAck(packetID uint16, pkt, qos byte, properties packets.Properties, reason packets.Code) packets.Packet {
|
||||
properties = packets.Properties{} // PRL
|
||||
if s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck {
|
||||
properties = packets.Properties{}
|
||||
}
|
||||
if reason.Code >= packets.ErrUnspecifiedError.Code {
|
||||
properties.ReasonString = reason.Reason
|
||||
}
|
||||
|
180
server_test.go
180
server_test.go
@@ -1198,6 +1198,50 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
|
||||
require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
|
||||
}
|
||||
|
||||
func TestServerBuildAck(t *testing.T) {
|
||||
s := newServer()
|
||||
properties := packets.Properties{
|
||||
User: []packets.UserProperty{
|
||||
{Key: "hello", Val: "世界"},
|
||||
},
|
||||
}
|
||||
ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1)
|
||||
require.Equal(t, packets.Puback, ack.FixedHeader.Type)
|
||||
require.Equal(t, uint8(1), ack.FixedHeader.Qos)
|
||||
require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
|
||||
require.Equal(t, properties, ack.Properties)
|
||||
}
|
||||
|
||||
func TestServerBuildAckError(t *testing.T) {
|
||||
s := newServer()
|
||||
properties := packets.Properties{
|
||||
User: []packets.UserProperty{
|
||||
{Key: "hello", Val: "世界"},
|
||||
},
|
||||
}
|
||||
ack := s.buildAck(7, packets.Puback, 1, properties, packets.ErrMalformedPacket)
|
||||
require.Equal(t, packets.Puback, ack.FixedHeader.Type)
|
||||
require.Equal(t, uint8(1), ack.FixedHeader.Qos)
|
||||
require.Equal(t, packets.ErrMalformedPacket.Code, ack.ReasonCode)
|
||||
properties.ReasonString = packets.ErrMalformedPacket.Reason
|
||||
require.Equal(t, properties, ack.Properties)
|
||||
}
|
||||
|
||||
func TestServerBuildAckPahoCompatibility(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
|
||||
properties := packets.Properties{
|
||||
User: []packets.UserProperty{
|
||||
{Key: "hello", Val: "世界"},
|
||||
},
|
||||
}
|
||||
ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1)
|
||||
require.Equal(t, packets.Puback, ack.FixedHeader.Type)
|
||||
require.Equal(t, uint8(1), ack.FixedHeader.Qos)
|
||||
require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
|
||||
require.Equal(t, packets.Properties{}, ack.Properties)
|
||||
}
|
||||
|
||||
func TestServerProcessPacketAndNextImmediate(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newTestClient()
|
||||
@@ -1222,7 +1266,7 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) {
|
||||
require.Equal(t, int32(4), cl.State.Inflight.sendQuota)
|
||||
}
|
||||
|
||||
func TestServerProcessPacketPublishAckFailure(t *testing.T) {
|
||||
func TestServerProcessPublishAckFailure(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
@@ -1236,6 +1280,92 @@ func TestServerProcessPacketPublishAckFailure(t *testing.T) {
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) {
|
||||
s := newServer()
|
||||
hook := new(modifiedHookBase)
|
||||
hook.fail = true
|
||||
hook.err = packets.ErrUnspecifiedError
|
||||
err := s.AddHook(hook, nil)
|
||||
require.NoError(t,err)
|
||||
|
||||
cl, _, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Clients.Add(cl)
|
||||
w.Close()
|
||||
|
||||
err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) {
|
||||
s := newServer()
|
||||
hook := new(modifiedHookBase)
|
||||
hook.fail = true
|
||||
hook.err = packets.ErrPayloadFormatInvalid
|
||||
err := s.AddHook(hook, nil)
|
||||
require.NoError(t,err)
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Clients.Add(cl)
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPubackUnexpectedError).RawBytes, buf)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) {
|
||||
s := newServer()
|
||||
hook := new(modifiedHookBase)
|
||||
hook.fail = true
|
||||
hook.err = packets.CodeSuccessIgnore
|
||||
err := s.AddHook(hook, nil)
|
||||
require.NoError(t,err)
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
receiver, r2, w2 := newTestClient()
|
||||
receiver.ID = "receiver"
|
||||
s.Clients.Add(receiver)
|
||||
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
|
||||
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
|
||||
require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
|
||||
|
||||
receiverBuf := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(r2)
|
||||
require.NoError(t, err)
|
||||
receiverBuf <- buf
|
||||
}()
|
||||
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
w2.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
|
||||
require.Equal(t, []byte{}, <-receiverBuf)
|
||||
require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
|
||||
}
|
||||
|
||||
func TestServerProcessPacketPublishMaximumReceive(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Serve()
|
||||
@@ -1393,6 +1523,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
|
||||
require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
|
||||
}
|
||||
|
||||
|
||||
func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newTestClient()
|
||||
@@ -1537,6 +1668,32 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) {
|
||||
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishSubscriberIdentifier).RawBytes, <-receiverBuf)
|
||||
}
|
||||
|
||||
func TestPublishToSubscribersPkIgnore(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "#", Identifier: 1})
|
||||
require.True(t, subbed)
|
||||
|
||||
go func() {
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||
pk.Ignore = true
|
||||
s.publishToSubscribers(pk)
|
||||
time.Sleep(time.Millisecond)
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
receiverBuf := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
receiverBuf <- buf
|
||||
}()
|
||||
|
||||
require.Equal(t, []byte{}, <-receiverBuf)
|
||||
}
|
||||
|
||||
|
||||
func TestPublishToClientServerDowngradeQos(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Options.Capabilities.MaximumQos = 1
|
||||
@@ -1846,6 +2003,27 @@ func TestNoRetainMessageIfUnavailable(t *testing.T) {
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained))
|
||||
}
|
||||
|
||||
|
||||
func TestNoRetainMessageIfPkIgnore(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
|
||||
pk.Ignore = true
|
||||
s.retainMessage(new(Client), pk)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained))
|
||||
}
|
||||
|
||||
func TestNoRetainMessage(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
s.retainMessage(new(Client), *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Retained))
|
||||
}
|
||||
|
||||
func TestServerProcessPacketPuback(t *testing.T) {
|
||||
tt := ProtocolTest{
|
||||
{
|
||||
|
Reference in New Issue
Block a user