diff --git a/server/server_test.go b/server/server_test.go index d878599..267612e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,6 +2,7 @@ package server import ( //"errors" + "fmt" "io/ioutil" "net" "strconv" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/mochi-co/mqtt/server/events" "github.com/mochi-co/mqtt/server/internal/circ" "github.com/mochi-co/mqtt/server/internal/clients" "github.com/mochi-co/mqtt/server/internal/packets" @@ -650,7 +652,7 @@ func TestServerProcessPublishQoS2(t *testing.T) { require.Equal(t, int64(0), atomic.LoadInt64(&s.System.Retained)) } -func TestServerProcessPublishUnretain(t *testing.T) { +func TestServerProcessPublishUnretainByEmptyPayload(t *testing.T) { s, cl1, r1, w1 := setupClient() s.Clients.Add(cl1) @@ -920,6 +922,154 @@ func TestServerPublishInlineSysTopicError(t *testing.T) { require.Equal(t, int64(0), s.System.BytesSent) } +func TestServerProcessPublishHookOnMessage(t *testing.T) { + s, cl1, r1, w1 := setupClient() + s.Clients.Add(cl1) + s.Topics.Subscribe("a/b/+", cl1.ID, 0) + + var hookedPacket events.Packet + var hookedClient events.Client + s.Events.OnMessage = func(cl events.Client, pk events.Packet) { + hookedPacket = pk + hookedClient = cl + } + + ack1 := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(r1) + if err != nil { + panic(err) + } + ack1 <- buf + }() + + pk1 := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + } + err := s.processPacket(cl1, pk1) + + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + + require.Equal(t, events.Client{ + ID: "mochi", + Listener: "", + }, hookedClient) + + require.Equal(t, events.Packet(pk1), hookedPacket) + + w1.Close() + + require.Equal(t, []byte{ + byte(packets.Publish << 4), 12, + 0, 5, + 'a', '/', 'b', '/', 'c', + 'h', 'e', 'l', 'l', 'o', + }, <-ack1) + + require.Equal(t, int64(14), s.System.BytesSent) +} + +func TestServerProcessPublishHookOnMessageModify(t *testing.T) { + s, cl1, r1, w1 := setupClient() + s.Clients.Add(cl1) + s.Topics.Subscribe("a/b/+", cl1.ID, 0) + + var hookedPacket events.Packet + var hookedClient events.Client + s.Events.OnMessageModify = func(cl events.Client, pk events.Packet) (events.Packet, error) { + hookedPacket = pk + hookedPacket.Payload = []byte("world") + hookedClient = cl + return hookedPacket, nil + } + + ack1 := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(r1) + if err != nil { + panic(err) + } + ack1 <- buf + }() + + pk1 := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + } + err := s.processPacket(cl1, pk1) + + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + + require.Equal(t, events.Client{ + ID: "mochi", + Listener: "", + }, hookedClient) + + w1.Close() + + require.Equal(t, []byte{ + byte(packets.Publish << 4), 12, + 0, 5, + 'a', '/', 'b', '/', 'c', + 'w', 'o', 'r', 'l', 'd', + }, <-ack1) + + require.Equal(t, int64(14), s.System.BytesSent) +} + +func TestServerProcessPublishHookOnMessageModifyError(t *testing.T) { + s, cl1, r1, w1 := setupClient() + s.Clients.Add(cl1) + s.Topics.Subscribe("a/b/+", cl1.ID, 0) + + s.Events.OnMessageModify = func(cl events.Client, pk events.Packet) (events.Packet, error) { + pkx := pk + pkx.Payload = []byte("world") + return pkx, fmt.Errorf("error") + } + + ack1 := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(r1) + if err != nil { + panic(err) + } + ack1 <- buf + }() + + pk1 := packets.Packet{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + } + err := s.processPacket(cl1, pk1) + + require.NoError(t, err) + time.Sleep(10 * time.Millisecond) + + w1.Close() + + require.Equal(t, []byte{ + byte(packets.Publish << 4), 12, + 0, 5, + 'a', '/', 'b', '/', 'c', + 'h', 'e', 'l', 'l', 'o', + }, <-ack1) + + require.Equal(t, int64(14), s.System.BytesSent) +} + func TestServerProcessPuback(t *testing.T) { s, cl, _, _ := setupClient() cl.Inflight.Set(11, clients.InflightMessage{Packet: packets.Packet{PacketID: 11}, Sent: 0})