More Tests

This commit is contained in:
Mochi
2019-11-02 16:46:21 +00:00
parent 21fe6d6bf4
commit 5d0aa0f1b5
7 changed files with 174 additions and 346 deletions

View File

@@ -80,7 +80,12 @@ func newBuffer(size, block int64) buffer {
func (b *buffer) Close() { func (b *buffer) Close() {
atomic.StoreInt64(&b.done, 1) atomic.StoreInt64(&b.done, 1)
debug.Println(b.id, "[B] STORE done=1 buffer closing") debug.Println(b.id, "[B] STORE done=1 buffer closing")
b.Bump()
debug.Println(b.id, "[B] DONE REBROADCASTED")
}
// Bump will broadcast all sync conditions.
func (b *buffer) Bump() {
debug.Println(b.id, "##### [X] wcond.Locking") debug.Println(b.id, "##### [X] wcond.Locking")
b.wcond.L.Lock() b.wcond.L.Lock()
debug.Println(b.id, "##### [X] wcond.Locked") debug.Println(b.id, "##### [X] wcond.Locked")
@@ -100,23 +105,28 @@ func (b *buffer) Close() {
debug.Println(b.id, "##### [Y] rcond.Unlocking") debug.Println(b.id, "##### [Y] rcond.Unlocking")
b.rcond.L.Unlock() b.rcond.L.Unlock()
debug.Println(b.id, "##### [Y] rcond.Unlocked") debug.Println(b.id, "##### [Y] rcond.Unlocked")
debug.Println(b.id, "[B] DONE REBROADCASTED")
} }
// Get will return the tail and head positions of the buffer. // Get will return the tail and head positions of the buffer.
// This method is for use with testing.
func (b *buffer) GetPos() (int64, int64) { func (b *buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head) return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
} }
// SetPos sets the head and tail of the buffer. This method should only be // Get returns the internal buffer. This method is for use with testing.
// used for testing. // This method is for use with testing.
func (b *buffer) Get() []byte {
return b.buf
}
// SetPos sets the head and tail of the buffer.
// This method is for use with testing.
func (b *buffer) SetPos(tail, head int64) { func (b *buffer) SetPos(tail, head int64) {
b.tail, b.head = tail, head b.tail, b.head = tail, head
} }
// Set writes bytes to a byte buffer. This method should only be used for testing // Set writes bytes to a byte buffer.
// and will panic if out of range. // This method is for use with testing and will panic if out of range.
func (b *buffer) Set(p []byte, start, end int) { func (b *buffer) Set(p []byte, start, end int) {
o := 0 o := 0
for i := start; i < end; i++ { for i := start; i < end; i++ {

View File

@@ -195,17 +195,14 @@ func (cl *client) close() {
cl.done.Do(func() { cl.done.Do(func() {
if cl.end != nil { if cl.end != nil {
close(cl.end) // Signal to stop listening for packets in mqtt readClient loop. close(cl.end) // Signal to stop listening for packets in mqtt readClient loop.
fmt.Println("closed readClient end") fmt.Println("closed readClient end", cl.end)
} }
// Close the processor. // Close the processor.
cl.p.Stop() cl.p.Stop()
fmt.Println("signalled stop, waiting") fmt.Println("signalled stop, waiting")
fmt.Println("all processors ended") fmt.Println("all processors ended")
cl.wasClosed = true
// Close the network connection.
//cl.p.Conn.Close() // Error is irrelevant so can be ommitted here.
//cl.p.Conn = nil
}) })
} }

View File

@@ -251,7 +251,7 @@ func TestClientClose(t *testing.T) {
case _, ok = <-client.end: case _, ok = <-client.end:
} }
require.Equal(t, false, ok) require.Equal(t, false, ok)
require.Nil(t, client.p.Conn) require.Equal(t, true, client.wasClosed)
} }
func TestInFlightSet(t *testing.T) { func TestInFlightSet(t *testing.T) {

109
mqtt.go
View File

@@ -27,7 +27,6 @@ var (
ErrListenerIDExists = errors.New("Listener id already exists") ErrListenerIDExists = errors.New("Listener id already exists")
ErrReadConnectFixedHeader = errors.New("Error reading fixed header on CONNECT packet") ErrReadConnectFixedHeader = errors.New("Error reading fixed header on CONNECT packet")
ErrReadConnectPacket = errors.New("Error reading CONNECT packet") ErrReadConnectPacket = errors.New("Error reading CONNECT packet")
ErrFirstPacketInvalid = errors.New("First packet was not CONNECT packet")
ErrReadConnectInvalid = errors.New("CONNECT packet was not valid") ErrReadConnectInvalid = errors.New("CONNECT packet was not valid")
ErrConnectNotAuthorized = errors.New("CONNECT packet was not authorized") ErrConnectNotAuthorized = errors.New("CONNECT packet was not authorized")
ErrReadFixedHeader = errors.New("Error reading fixed header") ErrReadFixedHeader = errors.New("Error reading fixed header")
@@ -118,7 +117,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
// Ensure first packet is a connect packet. // Ensure first packet is a connect packet.
msg, ok := pk.(*packets.ConnectPacket) msg, ok := pk.(*packets.ConnectPacket)
if !ok { if !ok {
return ErrFirstPacketInvalid return ErrReadConnectInvalid
} }
// Ensure the packet conforms to MQTT CONNECT specifications. // Ensure the packet conforms to MQTT CONNECT specifications.
@@ -167,111 +166,11 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
ReturnCode: retcode, ReturnCode: retcode,
}) })
if err != nil { if err != nil {
fmt.Println("WRITECLIENT err", err)
return err return err
} }
fmt.Println("----------------") fmt.Println("----------------")
// Resend any unacknowledged QOS messages still pending for the client.
/*err = s.resendInflight(client)
if err != nil {
return err
}*/
// Block and listen for more packets, and end if an error or nil packet occurs.
var sendLWT bool
err = s.readClient(client)
if err != nil {
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
}
fmt.Println("** stop and wait", err)
// Publish last will and testament.
s.closeClient(client, sendLWT)
fmt.Println("stopped")
return err // capture err or nil from readClient.
}
/*
// EstablishConnection establishes a new client connection with the broker.
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
// Create a new packets parser which will parse all packets for this client,
// using buffered writers and readers.
p := NewParser(c, bufio.NewReaderSize(c, rwBufSize), bufio.NewWriterSize(c, rwBufSize))
// Pull the header from the first packet and check for a CONNECT message.
fh := new(packets.FixedHeader)
err := p.ReadFixedHeader(fh)
if err != nil {
return ErrReadConnectFixedHeader
}
// Read the first packet expecting a CONNECT message.
pk, err := p.Read()
if err != nil {
return ErrReadConnectPacket
}
// Ensure first packet is a connect packet.
msg, ok := pk.(*packets.ConnectPacket)
if !ok {
return ErrFirstPacketInvalid
}
// Ensure the packet conforms to MQTT CONNECT specifications.
retcode, _ := msg.Validate()
if retcode != packets.Accepted {
return ErrReadConnectInvalid
}
// If a username and password has been provided, perform authentication.
if msg.Username != "" && !ac.Authenticate(msg.Username, msg.Password) {
retcode = packets.CodeConnectNotAuthorised
return ErrConnectNotAuthorized
}
// Create a new client with this connection.
client := newClient(p, msg, ac)
client.listener = lid // Note the listener the client is connected to.
// If it's not a clean session and a client already exists with the same
// id, inherit the session.
var sessionPresent bool
if existing, ok := s.clients.get(msg.ClientIdentifier); ok {
existing.Lock()
existing.close()
if msg.CleanSession {
for k := range existing.subscriptions {
s.topics.Unsubscribe(k, existing.id)
}
} else {
client.inFlight = existing.inFlight // Inherit from existing session.
client.subscriptions = existing.subscriptions
sessionPresent = true
}
existing.Unlock()
}
// Add the new client to the clients manager.
s.clients.add(client)
// Send a CONNACK back to the client.
err = s.writeClient(client, &packets.ConnackPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: sessionPresent,
ReturnCode: retcode,
})
if err != nil {
return err
}
// Resend any unacknowledged QOS messages still pending for the client. // Resend any unacknowledged QOS messages still pending for the client.
err = s.resendInflight(client) err = s.resendInflight(client)
if err != nil { if err != nil {
@@ -285,12 +184,15 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3] sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
} }
fmt.Println("** stop and wait", err)
// Publish last will and testament. // Publish last will and testament.
s.closeClient(client, sendLWT) s.closeClient(client, sendLWT)
fmt.Println("stopped", err)
return err // capture err or nil from readClient. return err // capture err or nil from readClient.
} }
*/
// resendInflight republishes any inflight messages to the client. // resendInflight republishes any inflight messages to the client.
func (s *Server) resendInflight(cl *client) error { func (s *Server) resendInflight(cl *client) error {
@@ -312,7 +214,6 @@ func (s *Server) readClient(cl *client) error {
var err error var err error
var pk packets.Packet var pk packets.Packet
fh := new(packets.FixedHeader) fh := new(packets.FixedHeader)
fmt.Println("STARTING READ CLIENT LOOP")
DONE: DONE:
for { for {
select { select {

View File

@@ -52,9 +52,10 @@ func (q *quietWriter) Flush() error {
} }
func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) { func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) {
r, w = net.Pipe()
s = New() s = New()
cl = newClient( cl = newClient( // new(MockNetConn)
NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize)), NewProcessor(r, circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize)),
&packets.ConnectPacket{ &packets.ConnectPacket{
ClientIdentifier: id, ClientIdentifier: id,
}, },
@@ -176,6 +177,8 @@ func BenchmarkServerClose(b *testing.B) {
s.Close() s.Close()
s.listeners.Delete("t1") s.listeners.Delete("t1")
} }
fmt.Println()
} }
/* /*
@@ -183,75 +186,17 @@ func BenchmarkServerClose(b *testing.B) {
* Server Establish Connection * Server Establish Connection
*/ */
func TestServerEstablishConnectionOKCleanSession(t *testing.T) { func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
for x := 0; x < 10000; x++ {
fmt.Println("===========================", x)
s := New()
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
//time.Sleep(10 * time.Millisecond)
s.clients.Lock()
for _, v := range s.clients.internal {
v.close()
break
}
s.clients.Unlock()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
fmt.Println()
w.Close()
}
}
/*
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
r2, _ := net.Pipe()
s := New() s := New()
px := NewProcessor(r2, circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
px.Start()
s.clients.internal["zen"] = newClient( s.clients.internal["zen"] = newClient(
px, NewProcessor(new(MockNetConn), circ.NewReader(32, 8), circ.NewWriter(32, 8)),
&packets.ConnectPacket{ClientIdentifier: "zen"}, &packets.ConnectPacket{ClientIdentifier: "zen"},
new(auth.Allow), new(auth.Allow),
) )
s.clients.internal["zen"].subscriptions = map[string]byte{ subs := map[string]byte{
"a/b/c": 1, "a/b/c": 1,
} }
s.clients.internal["zen"].subscriptions = subs
r, w := net.Pipe() r, w := net.Pipe()
o := make(chan error) o := make(chan error)
@@ -271,8 +216,6 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
'z', 'e', 'n', // Client ID "zen" 'z', 'e', 'n', // Client ID "zen"
}) })
w.Write([]byte{byte(packets.Disconnect << 4), 0}) w.Write([]byte{byte(packets.Disconnect << 4), 0})
r.Close()
s.clients.internal["zen"].close()
}() }()
recv := make(chan []byte) recv := make(chan []byte)
@@ -284,21 +227,29 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
recv <- buf recv <- buf
}() }()
require.NoError(t, <-o) s.clients.Lock()
for _, v := range s.clients.internal {
v.close()
break
}
s.clients.Unlock()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{ require.Equal(t, []byte{
byte(packets.Connack << 4), 2, byte(packets.Connack << 4), 2,
0, packets.Accepted, 0, packets.Accepted,
}, <-recv) }, <-recv)
require.Empty(t, s.clients.internal["zen"].subscriptions)
w.Close() w.Close()
} }
/*
func TestServerEstablishConnectionOKInheritSession(t *testing.T) { func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
r2, w2 := net.Pipe()
s := New() s := New()
s.clients.internal["zen"] = newClient( s.clients.internal["zen"] = newClient(
NewParser(r2, newBufioReader(r2), newBufioWriter(w2)), NewProcessor(new(MockNetConn), circ.NewReader(32, 8), circ.NewWriter(32, 8)),
&packets.ConnectPacket{ClientIdentifier: "zen"}, &packets.ConnectPacket{ClientIdentifier: "zen"},
new(auth.Allow), new(auth.Allow),
) )
@@ -325,8 +276,6 @@ func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
'z', 'e', 'n', // Client ID "zen" 'z', 'e', 'n', // Client ID "zen"
}) })
w.Write([]byte{byte(packets.Disconnect << 4), 0}) w.Write([]byte{byte(packets.Disconnect << 4), 0})
r.Close()
s.clients.internal["zen"].close()
}() }()
recv := make(chan []byte) recv := make(chan []byte)
@@ -347,9 +296,51 @@ func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
w.Close() w.Close()
} }
func TestServerEstablishConnectionSendLWT(t *testing.T) {
s := New()
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 45, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
})
}()
go func() {
_, err := ioutil.ReadAll(w)
if err != nil {
fmt.Println("read closed")
}
}()
time.Sleep(time.Millisecond * 10)
r.Close()
w.Close()
s.clients.Lock()
for _, v := range s.clients.internal {
v.close()
break
}
s.clients.Unlock()
require.Error(t, <-o)
}
func TestServerEstablishConnectionBadFixedHeader(t *testing.T) { func TestServerEstablishConnectionBadFixedHeader(t *testing.T) {
r, w := net.Pipe() r, w := net.Pipe()
go func() { go func() {
w.Write([]byte{99}) w.Write([]byte{99})
w.Close() w.Close()
@@ -363,22 +354,6 @@ func TestServerEstablishConnectionBadFixedHeader(t *testing.T) {
require.Equal(t, ErrReadConnectFixedHeader, err) require.Equal(t, ErrReadConnectFixedHeader, err)
} }
func TestServerEstablishConnectionBadConnectPacket(t *testing.T) {
r, w := net.Pipe()
go func() {
w.Write([]byte{byte(packets.Connect << 4), 17})
w.Close()
}()
s := New()
err := s.EstablishConnection("tcp", r, new(auth.Allow))
r.Close()
require.Error(t, err)
require.Equal(t, ErrReadConnectPacket, err)
}
func TestServerEstablishConnectionNotConnectPacket(t *testing.T) { func TestServerEstablishConnectionNotConnectPacket(t *testing.T) {
r, w := net.Pipe() r, w := net.Pipe()
@@ -396,29 +371,21 @@ func TestServerEstablishConnectionNotConnectPacket(t *testing.T) {
r.Close() r.Close()
require.Error(t, err) require.Error(t, err)
require.Equal(t, ErrFirstPacketInvalid, err) require.Equal(t, ErrReadConnectInvalid, err)
} }
func TestServerEstablishConnectionInvalidConnectPacket(t *testing.T) { func TestServerEstablishConnectionInvalidConnectPacket(t *testing.T) {
r, w := net.Pipe() r, w := net.Pipe()
go func() { go func() {
w.Write([]byte{ w.Write([]byte{byte(packets.Connect << 4), 17})
byte(packets.Connect << 4), 13, // Fixed header
0, 2, // Protocol Name - MSB+LSB
'M', 'Q', // ** NON-CONFORMING Protocol Name
4, // Protocol Version
2, // Packet Flags
0, 45, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
})
w.Close() w.Close()
}() }()
s := New() s := New()
err := s.EstablishConnection("tcp", r, new(auth.Allow)) err := s.EstablishConnection("tcp", r, new(auth.Allow))
r.Close() r.Close()
require.Error(t, err) require.Error(t, err)
require.Equal(t, ErrReadConnectInvalid, err) require.Equal(t, ErrReadConnectInvalid, err)
} }
@@ -452,75 +419,8 @@ func TestServerEstablishConnectionBadAuth(t *testing.T) {
require.Equal(t, ErrConnectNotAuthorized, err) require.Equal(t, ErrConnectNotAuthorized, err)
} }
func TestServerEstablishConnectionWriteClientError(t *testing.T) {
r, w := net.Pipe()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 60, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
})
}()
o := make(chan error)
go func() {
s := New()
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
time.Sleep(5 * time.Millisecond)
w.Close()
require.Error(t, <-o)
r.Close()
}
func TestServerEstablishConnectionReadClientError(t *testing.T) {
r, w := net.Pipe()
o := make(chan error)
go func() {
s := New()
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags
0, 45, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
})
}()
go func() {
_, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
w.Close()
}()
time.Sleep(10 * time.Millisecond)
r.Close()
require.Error(t, <-o)
}
func TestResendInflight(t *testing.T) { func TestResendInflight(t *testing.T) {
s, _, _, cl := setupClient("zen") s, _, _, cl := setupClient("zen")
cl.p.W = new(quietWriter)
cl.inFlight.set(1, &inFlightMessage{ cl.inFlight.set(1, &inFlightMessage{
packet: &packets.PublishPacket{ packet: &packets.PublishPacket{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
@@ -544,16 +444,16 @@ func TestResendInflight(t *testing.T) {
'a', '/', 'b', '/', 'c', // Topic Name 'a', '/', 'b', '/', 'c', // Topic Name
0, 1, // packet id from qos=1 0, 1, // packet id from qos=1
'h', 'e', 'l', 'l', 'o', // Payload) 'h', 'e', 'l', 'l', 'o', // Payload)
}, cl.p.W.(*quietWriter).f[0]) }, cl.p.W.Get()[:16])
} }
func TestResendInflightWriteError(t *testing.T) { func TestResendInflightWriteError(t *testing.T) {
s, _, _, cl := setupClient("zen") s, _, _, cl := setupClient("zen")
cl.p.W = &quietWriter{errAfter: -1}
cl.inFlight.set(1, &inFlightMessage{ cl.inFlight.set(1, &inFlightMessage{
packet: &packets.PublishPacket{}, packet: &packets.PublishPacket{},
}) })
cl.p.W.Close()
err := s.resendInflight(cl) err := s.resendInflight(cl)
require.Error(t, err) require.Error(t, err)
} }
@@ -562,33 +462,33 @@ func TestResendInflightWriteError(t *testing.T) {
* Server Read Client * Server Read Client
*/ */
/*
func TestServerReadClientOK(t *testing.T) { func TestServerReadClientOK(t *testing.T) {
s, r, w, cl := setupClient("zen") s, r, w, cl := setupClient("zen")
go func() { go func() {
close(cl.end)
w.Write([]byte{ w.Write([]byte{
byte(packets.Publish << 4), 18, // Fixed header byte(packets.Publish << 4), 18, // Fixed header
0, 5, // Topic Name - LSB+MSB 0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name 'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
}) })
close(cl.end)
}()
time.Sleep(10 * time.Millisecond) }()
o := make(chan error) o := make(chan error)
go func() { go func() {
o <- s.readClient(cl) o <- s.readClient(cl)
}() }()
require.NoError(t, <-o)
w.Close() w.Close()
r.Close() r.Close()
require.NoError(t, <-o)
} }
/*
func TestServerReadClientNoConn(t *testing.T) { func TestServerReadClientNoConn(t *testing.T) {
s, r, _, cl := setupClient("zen") s, r, _, cl := setupClient("zen")
cl.p.Conn.Close() cl.p.Conn.Close()

View File

@@ -11,6 +11,12 @@ import (
"github.com/mochi-co/mqtt/packets" "github.com/mochi-co/mqtt/packets"
) )
var (
// ErrInsufficientBytes indicates that there were not enough bytes
// in the buffer to read/peek.
ErrInsufficientBytes = fmt.Errorf("Insufficient bytes in buffer")
)
// Processor reads and writes bytes to a network connection. // Processor reads and writes bytes to a network connection.
type Processor struct { type Processor struct {
@@ -52,54 +58,34 @@ func NewProcessor(c net.Conn, r *circ.Reader, w *circ.Writer) *Processor {
// Start spins up the reader and writer goroutines. // Start spins up the reader and writer goroutines.
func (p *Processor) Start() { func (p *Processor) Start() {
p.started.Add(2) p.started.Add(2)
fmt.Println("\t\tWG Started ADD", 2)
go func(start, end *sync.WaitGroup) { go func(start, end *sync.WaitGroup) {
defer p.endedW.Done() defer p.endedW.Done()
p.started.Done() p.started.Done()
p.W.WriteTo(p.Conn)
fmt.Println("starting writeTo", p.Conn)
n, err := p.W.WriteTo(p.Conn)
if err != nil {
// ...
}
fmt.Println(">>> finished WriteTo", n, err)
}(p.started, p.endedW) }(p.started, p.endedW)
p.endedW.Add(1) p.endedW.Add(1)
go func(start, end *sync.WaitGroup) { go func(start, end *sync.WaitGroup) {
defer p.endedR.Done() defer p.endedR.Done()
p.started.Done() p.started.Done()
p.R.ReadFrom(p.Conn)
fmt.Println("starting readFrom", p.Conn)
n, err := p.R.ReadFrom(p.Conn)
if err != nil {
// ...
}
//
fmt.Println(">>> finished ReadFrom", n, err)
}(p.started, p.endedR) }(p.started, p.endedR)
p.endedR.Add(1) p.endedR.Add(1)
fmt.Println("\t\tWG Started Waiting")
p.started.Wait() p.started.Wait()
fmt.Println("Started OK") fmt.Println("Started OK")
} }
// Stop stops the processor goroutines. // Stop stops the processor goroutines.
func (p *Processor) Stop() { func (p *Processor) Stop() {
p.R.Close()
p.W.Close() p.W.Close()
p.endedW.Wait() p.endedW.Wait()
if p.Conn != nil { if p.Conn != nil {
p.Conn.Close() p.Conn.Close()
} else {
fmt.Println("--// Conn is nil")
} }
p.R.Close()
p.endedR.Wait() p.endedR.Wait()
} }
@@ -132,7 +118,6 @@ func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
// The remaining length value can be up to 5 bytes. Peek through each byte // The remaining length value can be up to 5 bytes. Peek through each byte
// looking for continue values, and if found increase the peek. Otherwise // looking for continue values, and if found increase the peek. Otherwise
// decode the bytes that were legit. // decode the bytes that were legit.
//p.fhBuffer = p.fhBuffer[:0]
buf := make([]byte, 0, 6) buf := make([]byte, 0, 6)
i := 1 i := 1
var b int64 = 2 // need this var later. var b int64 = 2 // need this var later.
@@ -143,6 +128,9 @@ func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
} }
// Add the byte to the length bytes slice. // Add the byte to the length bytes slice.
if i >= len(peeked) {
return ErrInsufficientBytes
}
buf = append(buf, peeked[i]) buf = append(buf, peeked[i])
// If it's not a continuation flag, end here. // If it's not a continuation flag, end here.
@@ -150,7 +138,7 @@ func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
break break
} }
// If i has reached 4 without a length terminator, throw a protocol violation. // If i has reached 4 without a length terminator, return a protocol violation.
i++ i++
if i == 4 { if i == 4 {
return packets.ErrOversizedLengthIndicator return packets.ErrOversizedLengthIndicator
@@ -217,18 +205,6 @@ func (p *Processor) Read() (pk packets.Packet, err error) {
return pk, err return pk, err
} }
/*
bt, err := p.R.Peek(p.FixedHeader.Remaining)
if err != nil {
return pk, err
}
err = p.R.CommitTail(p.FixedHeader.Remaining)
if err != nil {
return err
}
*/
// Decode the remaining packet values using a fresh copy of the bytes, // Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one. // otherwise the next packet will change the data of this one.
// ---- // ----

View File

@@ -1,6 +1,7 @@
package mqtt package mqtt
import ( import (
"fmt"
"net" "net"
"testing" "testing"
"time" "time"
@@ -38,8 +39,7 @@ func TestProcessorRefreshDeadline(t *testing.T) {
} }
func BenchmarkProcessorRefreshDeadline(b *testing.B) { func BenchmarkProcessorRefreshDeadline(b *testing.B) {
conn := new(MockNetConn) p := NewProcessor(new(MockNetConn), circ.NewReader(16, 4), circ.NewWriter(16, 4))
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
p.RefreshDeadline(10) p.RefreshDeadline(10)
@@ -47,31 +47,75 @@ func BenchmarkProcessorRefreshDeadline(b *testing.B) {
} }
func TestProcessorReadFixedHeader(t *testing.T) { func TestProcessorReadFixedHeader(t *testing.T) {
conn := new(MockNetConn) p := NewProcessor(new(MockNetConn), circ.NewReader(16, 4), circ.NewWriter(16, 4))
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
// Test null data.
fh := new(packets.FixedHeader) fh := new(packets.FixedHeader)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
// Test insufficient peeking.
fh = new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect << 4}, 0, 1)
p.R.SetPos(0, 1)
err = p.ReadFixedHeader(fh)
require.Error(t, err)
fh = new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2) p.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
p.R.SetPos(0, 2) p.R.SetPos(0, 2)
err = p.ReadFixedHeader(fh) fmt.Println("c")
err := p.ReadFixedHeader(fh)
require.NoError(t, err) require.NoError(t, err)
tail, head := p.R.GetPos() tail, head := p.R.GetPos()
require.Equal(t, int64(2), tail) require.Equal(t, int64(2), tail)
require.Equal(t, int64(2), head) require.Equal(t, int64(2), head)
p.Stop()
}
func TestProcessorReadFixedHeaderPeekError(t *testing.T) {
p := NewProcessor(new(MockNetConn), circ.NewReader(16, 4), circ.NewWriter(16, 4))
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
o <- p.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
p.Stop()
require.Error(t, <-o)
}
func TestProcessorReadFixedHeaderDecodeError(t *testing.T) {
p := NewProcessor(new(MockNetConn), circ.NewReader(16, 4), circ.NewWriter(16, 4))
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
p.R.SetPos(0, 2)
o <- p.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
p.Stop()
require.Error(t, <-o)
}
func TestProcessorReadFixedHeaderNoLengthTerminator(t *testing.T) {
p := NewProcessor(new(MockNetConn), circ.NewReader(16, 4), circ.NewWriter(16, 4))
fh := new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
p.R.SetPos(0, 5)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
require.Equal(t, packets.ErrOversizedLengthIndicator, err)
p.Stop()
}
func TestProcessorReadFixedHeaderInsufficientPeek(t *testing.T) {
p := NewProcessor(new(MockNetConn), circ.NewReader(16, 4), circ.NewWriter(16, 4))
fh := new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect << 4}, 0, 1)
p.R.SetPos(0, 1)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
p.Stop()
} }
func TestProcessorRead(t *testing.T) { func TestProcessorRead(t *testing.T) {