diff --git a/clients.go b/clients.go index e705b0b..9a9b355 100644 --- a/clients.go +++ b/clients.go @@ -104,7 +104,7 @@ func newClient(p *packets.Parser, pk *packets.ConnectPacket) *client { keepalive: pk.Keepalive, cleanSession: pk.CleanSession, inFlight: inFlight{ - internal: make(map[uint16]*inFlightMessage), + internal: make(map[uint16]*inFlightMessage, 2), }, } diff --git a/mqtt.go b/mqtt.go index 38cadcc..3413bd1 100644 --- a/mqtt.go +++ b/mqtt.go @@ -3,6 +3,7 @@ package mqtt import ( "bufio" "errors" + "log" "net" "github.com/mochi-co/mqtt/auth" @@ -261,6 +262,7 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { }) if err != nil { s.closeClient(cl, true) + return err } case *packets.PublishPacket: @@ -295,6 +297,7 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { err := s.writeClient(client, out) if err != nil { s.closeClient(client, true) + return err } // If QoS byte is set, save as message to inflight index so we @@ -318,6 +321,7 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { err := s.writeClient(cl, out) if err != nil { s.closeClient(cl, true) + return err } cl.inFlight.set(out.PacketID, out) @@ -333,6 +337,7 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { err := s.writeClient(cl, out) if err != nil { s.closeClient(cl, true) + return err } cl.inFlight.delete(msg.PacketID) @@ -358,8 +363,10 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { PacketID: msg.PacketID, ReturnCodes: retCodes, }) + if err != nil { s.closeClient(cl, true) + return err } // Publish out any retained messages matching the subscription filter. @@ -368,7 +375,9 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { for _, pkv := range messages { err := s.writeClient(cl, pkv) if err != nil { + log.Println("write err B", err) s.closeClient(cl, true) + return err } } } @@ -385,6 +394,7 @@ func (s *Server) processPacket(cl *client, pk packets.Packet) error { }) if err != nil { s.closeClient(cl, true) + return err } } diff --git a/mqtt_test.go b/mqtt_test.go index a802310..60a76ac 100644 --- a/mqtt_test.go +++ b/mqtt_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "log" "net" + "strings" "testing" "time" @@ -549,9 +550,38 @@ func TestServerWriteClientNilWriter(t *testing.T) { } func TestServerWriteClientWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + s.clients.add(cl) + // + err := s.writeClient(cl, &packets.PublishPacket{}) + require.Error(t, err) + w.Close() + r.Close() } +/* +func TestServerProcessPacketSubscribeWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + o := make(chan error, 2) + go func() { + r.Close() + err := s.processPacket(cl, &packets.SubscribePacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Subscribe, + }, + PacketID: 10, + }) + o <- err + }() + + require.Error(t, <-o) + close(o) + w.Close() +} +*/ + /* * Server Close Client @@ -645,7 +675,7 @@ func TestServerProcessPacketPingOK(t *testing.T) { r.Close() } -func TestServerProcessPacketPingClose(t *testing.T) { +func TestServerProcessPacketPingWriteError(t *testing.T) { s, r, w, cl := setupClient("zen") o := make(chan error, 2) @@ -659,7 +689,7 @@ func TestServerProcessPacketPingClose(t *testing.T) { }) }() - require.NoError(t, <-o) + require.Error(t, <-o) require.Nil(t, cl.p.Conn) close(o) @@ -744,6 +774,32 @@ func TestServerProcessPacketPublishRetain(t *testing.T) { r.Close() } +func TestServerProcessPacketPublishWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + s.clients.add(cl) + s.topics.Subscribe("a/+/c", cl.id, 0) + require.Nil(t, cl.inFlight.internal[1]) + + o := make(chan error, 2) + go func() { + r.Close() + err := s.processPacket(cl, &packets.PublishPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + }) + + o <- err + }() + + require.Error(t, <-o) + close(o) + w.Close() +} + func TestServerProcessPacketPuback(t *testing.T) { s, r, _, cl := setupClient("zen") @@ -803,6 +859,28 @@ func TestServerProcessPacketPubrec(t *testing.T) { r.Close() } +func TestServerProcessPacketPubrecWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + cl.inFlight.set(10, &packets.PublishPacket{PacketID: 10}) + + o := make(chan error, 2) + go func() { + r.Close() + err := s.processPacket(cl, &packets.PubrecPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Pubrec, + }, + PacketID: 10, + }) + o <- err + }() + + require.Error(t, <-o) + close(o) + w.Close() +} + func TestServerProcessPacketPubrel(t *testing.T) { s, r, w, cl := setupClient("zen") @@ -834,6 +912,28 @@ func TestServerProcessPacketPubrel(t *testing.T) { r.Close() } +func TestServerProcessPacketPubrelWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + cl.inFlight.set(10, &packets.PublishPacket{PacketID: 10}) + + o := make(chan error, 2) + go func() { + r.Close() + err := s.processPacket(cl, &packets.PubrelPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Pubrel, + }, + PacketID: 10, + }) + o <- err + }() + + require.Error(t, <-o) + close(o) + w.Close() +} + func TestServerProcessPacketPubcomp(t *testing.T) { s, r, _, cl := setupClient("zen") @@ -945,6 +1045,64 @@ func TestServerProcessPacketSubscribeRetained(t *testing.T) { r.Close() } +func TestServerProcessPacketSubscribeRetainedWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + s.topics.RetainMessage(&packets.PublishPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + Retain: true, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + }) + + require.Equal(t, 1, len(s.topics.Messages("a/b/c"))) + + o := make(chan error, 2) + go func() { + o <- s.processPacket(cl, &packets.SubscribePacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Subscribe, + }, + PacketID: 10, + Topics: []string{"a/b/c", "d/e/f"}, + Qoss: []byte{0, 1}, + }) + w.Close() + }() + + buf := make([]byte, 4) + for i := 0; i < 4; i++ { + r.Read(buf) + } + r.Close() + + require.Error(t, <-o) + close(o) + r.Close() +} + +func TestServerProcessPacketSubscribeWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + o := make(chan error, 2) + go func() { + r.Close() + err := s.processPacket(cl, &packets.SubscribePacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Subscribe, + }, + PacketID: 10, + }) + o <- err + }() + + require.Error(t, <-o) + close(o) + w.Close() +} + func TestServerProcessPacketUnsubscribe(t *testing.T) { s, r, w, cl := setupClient("zen") @@ -980,3 +1138,23 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) { require.Empty(t, s.topics.Subscribers("a/b/c")) require.Empty(t, s.topics.Subscribers("d/e/f")) } + +func TestServerProcessPacketUnsubscribeWriteError(t *testing.T) { + s, r, w, cl := setupClient("zen") + + o := make(chan error, 2) + go func() { + r.Close() + err := s.processPacket(cl, &packets.UnsubscribePacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Unsubscribe, + }, + PacketID: 10, + }) + o <- err + }() + + require.Error(t, <-o) + close(o) + w.Close() +}