Files
mochi-mqtt/mqtt_test.go
2019-11-29 21:33:22 +00:00

432 lines
8.7 KiB
Go

package mqtt
import (
"io/ioutil"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/internal/auth"
"github.com/mochi-co/mqtt/internal/circ"
"github.com/mochi-co/mqtt/internal/clients"
"github.com/mochi-co/mqtt/internal/listeners"
"github.com/mochi-co/mqtt/internal/packets"
)
/*
* Server
*/
func TestNew(t *testing.T) {
s := New()
require.NotNil(t, s)
require.NotNil(t, s.Listeners)
require.NotNil(t, s.Clients)
require.NotNil(t, s.Topics)
}
func BenchmarkNew(b *testing.B) {
for n := 0; n < b.N; n++ {
New()
}
}
func TestServerAddListener(t *testing.T) {
s := New()
require.NotNil(t, s)
err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil)
require.NoError(t, err)
// Add listener with config.
err = s.AddListener(listeners.NewMockListener("t2", ":1882"), &listeners.Config{
Auth: new(auth.Disallow),
})
require.NoError(t, err)
l, ok := s.Listeners.Get("t2")
require.Equal(t, true, ok)
require.Equal(t, new(auth.Disallow), l.(*listeners.MockListener).Config.Auth)
// Add listener on existing id
err = s.AddListener(listeners.NewMockListener("t1", ":1883"), nil)
require.Error(t, err)
require.Equal(t, ErrListenerIDExists, err)
}
func BenchmarkServerAddListener(b *testing.B) {
s := New()
l := listeners.NewMockListener("t1", ":1882")
for n := 0; n < b.N; n++ {
err := s.AddListener(l, nil)
if err != nil {
panic(err)
}
s.Listeners.Delete("t1")
}
}
func TestServerServe(t *testing.T) {
s := New()
require.NotNil(t, s)
err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil)
require.NoError(t, err)
s.Serve()
time.Sleep(time.Millisecond)
require.Equal(t, 1, s.Listeners.Len())
listener, ok := s.Listeners.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, true, listener.(*listeners.MockListener).IsServing)
}
func BenchmarkServerServe(b *testing.B) {
s := New()
l := listeners.NewMockListener("t1", ":1882")
err := s.AddListener(l, nil)
if err != nil {
panic(err)
}
for n := 0; n < b.N; n++ {
s.Serve()
}
}
/*
* Server Establish Connection
*/
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
s := New()
// Existing conneciton with subscription.
c, _ := net.Pipe()
cl := clients.NewClient(c, circ.NewReader(256, 8), circ.NewWriter(256, 8))
cl.ID = "mochi"
cl.Subscriptions = map[string]byte{
"a/b/c": 1,
}
s.Clients.Add(cl)
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
}
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
s := New()
c, _ := net.Pipe()
cl := clients.NewClient(c, circ.NewReader(256, 8), circ.NewWriter(256, 8))
cl.ID = "mochi"
cl.Subscriptions = map[string]byte{
"a/b/c": 1,
}
s.Clients.Add(cl)
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
1, packets.Accepted,
}, <-recv)
require.NotEmpty(t, clw.Subscriptions)
w.Close()
}
func TestServerEstablishConnectionBadFixedHeader(t *testing.T) {
s := New()
r, w := net.Pipe()
go func() {
w.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
w.Close()
}()
err := s.EstablishConnection("tcp", r, new(auth.Allow))
r.Close()
require.Error(t, err)
require.Equal(t, packets.ErrInvalidFlags, err)
}
func TestServerEstablishConnectionInvalidPacket(t *testing.T) {
s := New()
r, w := net.Pipe()
go func() {
w.Write([]byte{0, 0})
w.Close()
}()
err := s.EstablishConnection("tcp", r, new(auth.Allow))
r.Close()
require.Error(t, err)
}
func TestServerEstablishConnectionNotConnectPacket(t *testing.T) {
s := New()
r, w := net.Pipe()
go func() {
w.Write([]byte{byte(packets.Connack << 4), 2, 0, packets.Accepted})
w.Close()
}()
err := s.EstablishConnection("tcp", r, new(auth.Allow))
r.Close()
require.Error(t, err)
require.Equal(t, ErrReadConnectInvalid, err)
}
func TestServerEstablishConnectionBadAuth(t *testing.T) {
s := New()
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Disallow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 30, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
194, // Packet Flags
0, 20, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
0, 5, // Username MSB+LSB
'm', 'o', 'c', 'h', 'i',
0, 4, // Password MSB+LSB
'a', 'b', 'c', 'd',
})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
errx := <-o
time.Sleep(time.Millisecond)
r.Close()
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.CodeConnectBadAuthValues,
}, <-recv)
}
func TestServerEstablishConnectionPromptSendLWT(t *testing.T) {
s := New()
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID,
})
w.Write([]byte{0, 0}) // invalid packet
}()
// Receive the Connack
go func() {
_, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
}()
require.Error(t, <-o)
}
func TestWriteClient(t *testing.T) {
s := New()
r, w := net.Pipe()
cl := clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8))
cl.ID = "mochi"
cl.Start()
defer cl.Stop()
err := s.writeClient(cl, &packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Pubcomp,
Remaining: 2,
},
PacketID: 14,
})
require.NoError(t, err)
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r)
if err != nil {
panic(err)
}
recv <- buf
}()
time.Sleep(time.Millisecond)
w.Close()
require.Equal(t, []byte{
byte(packets.Pubcomp << 4), 2, // Fixed header
0, 14, // Packet ID - LSB+MSB
}, <-recv)
}
func TestWriteClientError(t *testing.T) {
s := New()
w, _ := net.Pipe()
cl := clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8))
cl.ID = "mochi"
err := s.writeClient(cl, new(packets.Packet))
require.Error(t, err)
}
/*
func TestResendInflight(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.inFlight.set(1, &inFlightMessage{
packet: &packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
Retain: true,
Dup: true,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 1,
},
sent: time.Now().Unix(),
})
err := s.resendInflight(cl)
require.NoError(t, err)
require.Equal(t, []byte{
byte(packets.Publish<<4 | 11), 14, // Fixed header QoS : 1
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
0, 1, // packet id from qos=1
'h', 'e', 'l', 'l', 'o', // Payload)
}, cl.p.W.Get()[:16])
}
func TestResendInflightWriteError(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.inFlight.set(1, &inFlightMessage{
packet: &packets.PublishPacket{},
})
cl.p.W.Close()
err := s.resendInflight(cl)
require.Error(t, err)
}
*/