Allow Publish to return custom Ack error responses (#256)

* Allow publish error returns as acks

* Add Ignore Packet, tests
This commit is contained in:
JB
2023-07-20 22:52:16 +01:00
committed by GitHub
parent 0234589152
commit ac812154e6
6 changed files with 282 additions and 11 deletions

View File

@@ -1198,6 +1198,50 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
}
func TestServerBuildAck(t *testing.T) {
s := newServer()
properties := packets.Properties{
User: []packets.UserProperty{
{Key: "hello", Val: "世界"},
},
}
ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1)
require.Equal(t, packets.Puback, ack.FixedHeader.Type)
require.Equal(t, uint8(1), ack.FixedHeader.Qos)
require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
require.Equal(t, properties, ack.Properties)
}
func TestServerBuildAckError(t *testing.T) {
s := newServer()
properties := packets.Properties{
User: []packets.UserProperty{
{Key: "hello", Val: "世界"},
},
}
ack := s.buildAck(7, packets.Puback, 1, properties, packets.ErrMalformedPacket)
require.Equal(t, packets.Puback, ack.FixedHeader.Type)
require.Equal(t, uint8(1), ack.FixedHeader.Qos)
require.Equal(t, packets.ErrMalformedPacket.Code, ack.ReasonCode)
properties.ReasonString = packets.ErrMalformedPacket.Reason
require.Equal(t, properties, ack.Properties)
}
func TestServerBuildAckPahoCompatibility(t *testing.T) {
s := newServer()
s.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
properties := packets.Properties{
User: []packets.UserProperty{
{Key: "hello", Val: "世界"},
},
}
ack := s.buildAck(7, packets.Puback, 1, properties, packets.CodeGrantedQos1)
require.Equal(t, packets.Puback, ack.FixedHeader.Type)
require.Equal(t, uint8(1), ack.FixedHeader.Qos)
require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
require.Equal(t, packets.Properties{}, ack.Properties)
}
func TestServerProcessPacketAndNextImmediate(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
@@ -1222,7 +1266,7 @@ func TestServerProcessPacketAndNextImmediate(t *testing.T) {
require.Equal(t, int32(4), cl.State.Inflight.sendQuota)
}
func TestServerProcessPacketPublishAckFailure(t *testing.T) {
func TestServerProcessPublishAckFailure(t *testing.T) {
s := newServer()
s.Serve()
defer s.Close()
@@ -1236,6 +1280,92 @@ func TestServerProcessPacketPublishAckFailure(t *testing.T) {
require.ErrorIs(t, err, io.ErrClosedPipe)
}
func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) {
s := newServer()
hook := new(modifiedHookBase)
hook.fail = true
hook.err = packets.ErrUnspecifiedError
err := s.AddHook(hook, nil)
require.NoError(t,err)
cl, _, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Clients.Add(cl)
w.Close()
err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrClosedPipe)
}
func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) {
s := newServer()
hook := new(modifiedHookBase)
hook.fail = true
hook.err = packets.ErrPayloadFormatInvalid
err := s.AddHook(hook, nil)
require.NoError(t,err)
s.Serve()
defer s.Close()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Clients.Add(cl)
go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
require.NoError(t, err)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPubackUnexpectedError).RawBytes, buf)
}
func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) {
s := newServer()
hook := new(modifiedHookBase)
hook.fail = true
hook.err = packets.CodeSuccessIgnore
err := s.AddHook(hook, nil)
require.NoError(t,err)
s.Serve()
defer s.Close()
cl, r, w := newTestClient()
s.Clients.Add(cl)
receiver, r2, w2 := newTestClient()
receiver.ID = "receiver"
s.Clients.Add(receiver)
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
receiverBuf := make(chan []byte)
go func() {
buf, err := io.ReadAll(r2)
require.NoError(t, err)
receiverBuf <- buf
}()
go func() {
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
require.NoError(t, err)
w.Close()
w2.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
require.Equal(t, []byte{}, <-receiverBuf)
require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
}
func TestServerProcessPacketPublishMaximumReceive(t *testing.T) {
s := newServer()
s.Serve()
@@ -1393,6 +1523,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
}
func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
@@ -1537,6 +1668,32 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishSubscriberIdentifier).RawBytes, <-receiverBuf)
}
func TestPublishToSubscribersPkIgnore(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
s.Clients.Add(cl)
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "#", Identifier: 1})
require.True(t, subbed)
go func() {
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
pk.Ignore = true
s.publishToSubscribers(pk)
time.Sleep(time.Millisecond)
w.Close()
}()
receiverBuf := make(chan []byte)
go func() {
buf, err := io.ReadAll(r)
require.NoError(t, err)
receiverBuf <- buf
}()
require.Equal(t, []byte{}, <-receiverBuf)
}
func TestPublishToClientServerDowngradeQos(t *testing.T) {
s := newServer()
s.Options.Capabilities.MaximumQos = 1
@@ -1846,6 +2003,27 @@ func TestNoRetainMessageIfUnavailable(t *testing.T) {
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained))
}
func TestNoRetainMessageIfPkIgnore(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
s.Clients.Add(cl)
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
pk.Ignore = true
s.retainMessage(new(Client), pk)
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Retained))
}
func TestNoRetainMessage(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
s.Clients.Add(cl)
s.retainMessage(new(Client), *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
require.Equal(t, int64(1), atomic.LoadInt64(&s.Info.Retained))
}
func TestServerProcessPacketPuback(t *testing.T) {
tt := ProtocolTest{
{