From 240187ffa2aadafa12fdb963257ddd733e98f0c2 Mon Sep 17 00:00:00 2001 From: Mochi Date: Sat, 2 Nov 2019 22:04:15 +0000 Subject: [PATCH] More tests, connecting MQTT --- mqtt.go | 7 +--- mqtt_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++---- processor.go | 1 + 3 files changed, 110 insertions(+), 13 deletions(-) diff --git a/mqtt.go b/mqtt.go index 5353b52..35e9a15 100644 --- a/mqtt.go +++ b/mqtt.go @@ -229,7 +229,6 @@ DONE: cl.p.RefreshDeadline(cl.keepalive) // Read in the fixed header of the packet. - //fh := new(packets.FixedHeader) err = cl.p.ReadFixedHeader(fh) if err != nil { return ErrReadFixedHeader @@ -245,12 +244,10 @@ DONE: pk, err = cl.p.Read() if err != nil { fmt.Println("RC READ ERR", err) - //return ErrReadPacketPayload - return nil + return ErrReadPacketPayload + //return nil } - fmt.Println("READ PACKET", pk) - // Validate the packet if necessary. _, err = pk.Validate() if err != nil { diff --git a/mqtt_test.go b/mqtt_test.go index b9dacc7..e7c6421 100644 --- a/mqtt_test.go +++ b/mqtt_test.go @@ -54,13 +54,10 @@ func (q *quietWriter) Flush() error { func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) { r, w = net.Pipe() s = New() - cl = newClient( // new(MockNetConn) - NewProcessor(r, circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize)), - &packets.ConnectPacket{ - ClientIdentifier: id, - }, - new(auth.Allow), - ) + p := NewProcessor(r, circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize)) + p.Start() + cl = newClient(p, &packets.ConnectPacket{ClientIdentifier: id}, new(auth.Allow)) + return } @@ -475,6 +472,78 @@ func TestServerReadClientOK(t *testing.T) { 'a', '/', 'b', '/', 'c', // Topic Name 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload }) + }() + + o := make(chan error) + go func() { + time.Sleep(time.Millisecond * 10) + o <- s.readClient(cl) + }() + + w.Close() + r.Close() + require.NoError(t, <-o) +} + +func TestServerReadClientNoConn(t *testing.T) { + s, r, _, cl := setupClient("zen") + cl.p.Conn.Close() + cl.p.Conn = nil + r.Close() + + err := s.readClient(cl) + require.Error(t, err) + require.Equal(t, ErrConnectionClosed, err) +} + +func TestServerReadClientBadFixedHeader(t *testing.T) { + s, r, w, cl := setupClient("zen") + + go func() { + close(cl.end) + w.Write([]byte{99}) + }() + + o := make(chan error) + go func() { + time.Sleep(time.Millisecond * 10) + o <- s.readClient(cl) + }() + + w.Close() + r.Close() + require.NoError(t, <-o) +} + +func TestServerReadClientDisconnect(t *testing.T) { + s, r, w, cl := setupClient("zen") + + go func() { + w.Write([]byte{packets.Disconnect << 4, 0}) + close(cl.end) + }() + + o := make(chan error) + go func() { + time.Sleep(time.Millisecond * 10) + o <- s.readClient(cl) + }() + + w.Close() + r.Close() + require.NoError(t, <-o) +} + +func TestServerReadClientBadPacketPayload(t *testing.T) { + s, r, w, cl := setupClient("zen") + + go func() { + w.Write([]byte{ + byte(packets.Publish << 4), 6, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', + 0, 11, // Packet ID - LSB+MSB, // malformed packet id. + }) }() @@ -483,9 +552,39 @@ func TestServerReadClientOK(t *testing.T) { o <- s.readClient(cl) }() + err := <-o + require.Error(t, err) + require.Equal(t, ErrReadPacketPayload, err) + w.Close() + r.Close() + +} + +func TestServerReadClientBadPacketValidation(t *testing.T) { + s, r, w, cl := setupClient("zen") + + go func() { + w.Write([]byte{ + byte(packets.Publish<<4) | 2, 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 0, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }) + + }() + + o := make(chan error) + go func() { + o <- s.readClient(cl) + }() + + err := <-o + require.Error(t, err) + require.Equal(t, ErrReadPacketValidation, err) + w.Close() r.Close() - require.NoError(t, <-o) } /* diff --git a/processor.go b/processor.go index 4ca1c1e..071d1be 100644 --- a/processor.go +++ b/processor.go @@ -72,6 +72,7 @@ func (p *Processor) Start() { p.R.ReadFrom(p.Conn) }(p.started, p.endedR) p.endedR.Add(1) + fmt.Println("Starting processor...") p.started.Wait() fmt.Println("Started OK")