From 80a163ece923ae3fe015c4f0967faed8b51f8469 Mon Sep 17 00:00:00 2001 From: Mochi Date: Sun, 6 Oct 2019 11:37:45 +0100 Subject: [PATCH] Process Publish+Recv --- mqtt.go | 106 +++++++++++++++++++++++++++++---------------------- mqtt_test.go | 47 +++++++++++++++++------ 2 files changed, 95 insertions(+), 58 deletions(-) diff --git a/mqtt.go b/mqtt.go index 042d063..426b0e6 100644 --- a/mqtt.go +++ b/mqtt.go @@ -245,42 +245,6 @@ DONE: return nil } -// writeClient writes packets to a client connection. -func (s *Server) writeClient(cl *client, pk packets.Packet) error { - - // Ensure Writer is open. - if cl.p.W == nil { - return ErrConnectionClosed - } - - // Encode packet to a pooled byte buffer. - buf := s.buffers.Get() - defer s.buffers.Put(buf) - err := pk.Encode(buf) - if err != nil { - return err - } - - // Write packet to client. - _, err = buf.WriteTo(cl.p.W) - if err != nil { - return err - } - - err = cl.p.W.Flush() - if err != nil { - return err - } - - // Refresh deadline to keep the connection alive. - cl.p.RefreshDeadline(cl.keepalive) - - // Log $SYS stats. - // @TODO ... - - return nil -} - // processPacket processes an inbound packet for a client. func (s *Server) processPacket(cl *client, pk packets.Packet) error { log.Println("PROCESSING PACKET", cl, pk) @@ -329,18 +293,29 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { log.Println(client, id, qos) // Make a copy of the packet to send to client. - outgoing := msg.Copy() - log.Println(outgoing) - // If the subscriber has a higher qos, inherit it. - /* if subscriptions.qos > outgoing.Qos { - outgoing.Qos = subscriptions.qos - } + out := msg.Copy() - // If QoS byte is set, ensure the message has an id. - if outgoing.Qos > 0 && outgoing.PacketID == 0 { - //outgoing.PacketID = client.nextPacketID() - }*/ + // If the client subscription has a higher qos, inherit it. + if qos > out.Qos { + out.Qos = qos + } + // If QoS byte is set, ensure the message has an id. + if out.Qos > 0 && out.PacketID == 0 { + out.PacketID = uint16(client.nextPacketID()) + } + + // Write the publish packet out to the receiving client. + err := s.writeClient(client, out) + if err != nil { + s.closeClient(client, true) + } + + // If QoS byte is set, save as message to inflight index so we + // can track delivery. + if out.Qos > 0 { + // client.handleQOS(out) + } } } @@ -383,6 +358,45 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { } */ +// writeClient writes packets to a client connection. +func (s *Server) writeClient(cl *client, pk packets.Packet) error { + + // Ensure Writer is open. + if cl.p.W == nil { + return ErrConnectionClosed + } + + // Encode packet to a pooled byte buffer. + buf := s.buffers.Get() + defer s.buffers.Put(buf) + err := pk.Encode(buf) + if err != nil { + return err + } + + log.Println("==", buf.Bytes()) + + // Write packet to client. + _, err = buf.WriteTo(cl.p.W) + if err != nil { + return err + } + + err = cl.p.W.Flush() + if err != nil { + return err + } + log.Println("WRITE CLIENT", cl.id) + + // Refresh deadline to keep the connection alive. + cl.p.RefreshDeadline(cl.keepalive) + + // Log $SYS stats. + // @TODO ... + + return nil +} + // closeClient closes a client connection and publishes any LWT messages. func (s *Server) closeClient(cl *client, sendLWT bool) error { diff --git a/mqtt_test.go b/mqtt_test.go index 095388a..60f2207 100644 --- a/mqtt_test.go +++ b/mqtt_test.go @@ -601,18 +601,23 @@ func TestServerProcessPacketPINGClose(t *testing.T) { func TestServerProcessPacketPublishOK(t *testing.T) { s := New() + // Sender r, w := net.Pipe() - p := packets.NewParser(r, newBufioReader(r), newBufioWriter(w)) - c1 := newClient(p, &packets.ConnectPacket{ClientIdentifier: "c1"}) + c1 := newClient( + packets.NewParser(r, newBufioReader(r), newBufioWriter(w)), + &packets.ConnectPacket{ClientIdentifier: "c1"}, + ) s.clients.add(c1) - s.topics.Subscribe("a/b/c", c1.id, 0) - s.topics.Subscribe("a/+/c", c1.id, 1) + // Subscriber r2, w2 := net.Pipe() - p2 := packets.NewParser(r2, newBufioReader(r2), newBufioWriter(w2)) - c2 := newClient(p2, &packets.ConnectPacket{ClientIdentifier: "c2"}) + c2 := newClient( + packets.NewParser(r2, newBufioReader(r2), newBufioWriter(w2)), + &packets.ConnectPacket{ClientIdentifier: "c2"}, + ) s.clients.add(c2) s.topics.Subscribe("a/b/+", c2.id, 0) + s.topics.Subscribe("a/+/c", c2.id, 1) o := make(chan error, 2) go func() { @@ -623,16 +628,34 @@ func TestServerProcessPacketPublishOK(t *testing.T) { TopicName: "a/b/c", Payload: []byte("hello"), }) + r.Close() + w.Close() + w2.Close() }() - time.Sleep(10 * time.Millisecond) - require.NoError(t, <-o) - close(o) + recv := make(chan []byte) + go func() { + buf, err := ioutil.ReadAll(r2) + if err != nil { + panic(err) + } + recv <- buf + }() - w.Close() - r.Close() + require.NoError(t, <-o) + require.Equal(t, + []byte{ + byte(packets.Publish<<4 | 2), 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 1, // packet id from qos=1 + 'h', 'e', 'l', 'l', 'o', // Payload + }, + <-recv, + ) + close(o) + close(recv) r2.Close() - w2.Close() } func TestServerProcessPacketPublishRetain(t *testing.T) {