Process Publish+Recv

This commit is contained in:
Mochi
2019-10-06 11:37:45 +01:00
parent dd8945a880
commit 80a163ece9
2 changed files with 95 additions and 58 deletions

106
mqtt.go
View File

@@ -245,42 +245,6 @@ DONE:
return nil
}
// writeClient writes packets to a client connection.
func (s *Server) writeClient(cl *client, pk packets.Packet) error {
// Ensure Writer is open.
if cl.p.W == nil {
return ErrConnectionClosed
}
// Encode packet to a pooled byte buffer.
buf := s.buffers.Get()
defer s.buffers.Put(buf)
err := pk.Encode(buf)
if err != nil {
return err
}
// Write packet to client.
_, err = buf.WriteTo(cl.p.W)
if err != nil {
return err
}
err = cl.p.W.Flush()
if err != nil {
return err
}
// Refresh deadline to keep the connection alive.
cl.p.RefreshDeadline(cl.keepalive)
// Log $SYS stats.
// @TODO ...
return nil
}
// processPacket processes an inbound packet for a client.
func (s *Server) processPacket(cl *client, pk packets.Packet) error {
log.Println("PROCESSING PACKET", cl, pk)
@@ -329,18 +293,29 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error {
log.Println(client, id, qos)
// Make a copy of the packet to send to client.
outgoing := msg.Copy()
log.Println(outgoing)
// If the subscriber has a higher qos, inherit it.
/* if subscriptions.qos > outgoing.Qos {
outgoing.Qos = subscriptions.qos
}
out := msg.Copy()
// If QoS byte is set, ensure the message has an id.
if outgoing.Qos > 0 && outgoing.PacketID == 0 {
//outgoing.PacketID = client.nextPacketID()
}*/
// If the client subscription has a higher qos, inherit it.
if qos > out.Qos {
out.Qos = qos
}
// If QoS byte is set, ensure the message has an id.
if out.Qos > 0 && out.PacketID == 0 {
out.PacketID = uint16(client.nextPacketID())
}
// Write the publish packet out to the receiving client.
err := s.writeClient(client, out)
if err != nil {
s.closeClient(client, true)
}
// If QoS byte is set, save as message to inflight index so we
// can track delivery.
if out.Qos > 0 {
// client.handleQOS(out)
}
}
}
@@ -383,6 +358,45 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error {
}
*/
// writeClient writes packets to a client connection.
func (s *Server) writeClient(cl *client, pk packets.Packet) error {
// Ensure Writer is open.
if cl.p.W == nil {
return ErrConnectionClosed
}
// Encode packet to a pooled byte buffer.
buf := s.buffers.Get()
defer s.buffers.Put(buf)
err := pk.Encode(buf)
if err != nil {
return err
}
log.Println("==", buf.Bytes())
// Write packet to client.
_, err = buf.WriteTo(cl.p.W)
if err != nil {
return err
}
err = cl.p.W.Flush()
if err != nil {
return err
}
log.Println("WRITE CLIENT", cl.id)
// Refresh deadline to keep the connection alive.
cl.p.RefreshDeadline(cl.keepalive)
// Log $SYS stats.
// @TODO ...
return nil
}
// closeClient closes a client connection and publishes any LWT messages.
func (s *Server) closeClient(cl *client, sendLWT bool) error {

View File

@@ -601,18 +601,23 @@ func TestServerProcessPacketPINGClose(t *testing.T) {
func TestServerProcessPacketPublishOK(t *testing.T) {
s := New()
// Sender
r, w := net.Pipe()
p := packets.NewParser(r, newBufioReader(r), newBufioWriter(w))
c1 := newClient(p, &packets.ConnectPacket{ClientIdentifier: "c1"})
c1 := newClient(
packets.NewParser(r, newBufioReader(r), newBufioWriter(w)),
&packets.ConnectPacket{ClientIdentifier: "c1"},
)
s.clients.add(c1)
s.topics.Subscribe("a/b/c", c1.id, 0)
s.topics.Subscribe("a/+/c", c1.id, 1)
// Subscriber
r2, w2 := net.Pipe()
p2 := packets.NewParser(r2, newBufioReader(r2), newBufioWriter(w2))
c2 := newClient(p2, &packets.ConnectPacket{ClientIdentifier: "c2"})
c2 := newClient(
packets.NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
&packets.ConnectPacket{ClientIdentifier: "c2"},
)
s.clients.add(c2)
s.topics.Subscribe("a/b/+", c2.id, 0)
s.topics.Subscribe("a/+/c", c2.id, 1)
o := make(chan error, 2)
go func() {
@@ -623,16 +628,34 @@ func TestServerProcessPacketPublishOK(t *testing.T) {
TopicName: "a/b/c",
Payload: []byte("hello"),
})
r.Close()
w.Close()
w2.Close()
}()
time.Sleep(10 * time.Millisecond)
require.NoError(t, <-o)
close(o)
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r2)
if err != nil {
panic(err)
}
recv <- buf
}()
w.Close()
r.Close()
require.NoError(t, <-o)
require.Equal(t,
[]byte{
byte(packets.Publish<<4 | 2), 14, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
0, 1, // packet id from qos=1
'h', 'e', 'l', 'l', 'o', // Payload
},
<-recv,
)
close(o)
close(recv)
r2.Close()
w2.Close()
}
func TestServerProcessPacketPublishRetain(t *testing.T) {