bytes buffer to pool

This commit is contained in:
Mochi
2019-12-01 21:44:05 +00:00
parent 7331d93ada
commit 7fa7fffc89
14 changed files with 236 additions and 228 deletions

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"log" "log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@@ -11,9 +12,14 @@ import (
"github.com/mochi-co/mqtt" "github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/internal/listeners" "github.com/mochi-co/mqtt/internal/listeners"
_ "net/http/pprof"
) )
func main() { func main() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
done := make(chan bool, 1) done := make(chan bool, 1)

1
go.mod
View File

@@ -6,6 +6,7 @@ require (
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23 github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164 github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164
github.com/pkg/profile v1.4.0
github.com/rs/xid v1.2.1 github.com/rs/xid v1.2.1
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.4.0
) )

2
go.sum
View File

@@ -8,6 +8,8 @@ github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8 h1:BIY2BMCLHm6hE/SU
github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8/go.mod h1:AqE7zHPhLOj61seX0vXvzpGiD9Q3Bx5LQPf/FleHKWc= github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8/go.mod h1:AqE7zHPhLOj61seX0vXvzpGiD9Q3Bx5LQPf/FleHKWc=
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164 h1:XGYo79ZRE9pQE9B5iZCYw3VLaq88PfxcdvDf9crG+dQ= github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164 h1:XGYo79ZRE9pQE9B5iZCYw3VLaq88PfxcdvDf9crG+dQ=
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164/go.mod h1:LfBrWXdsMaDKL0ZjcbnLjeYL48Nlo1nW4MltMDYqr44= github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164/go.mod h1:LfBrWXdsMaDKL0ZjcbnLjeYL48Nlo1nW4MltMDYqr44=
github.com/pkg/profile v1.4.0 h1:uCmaf4vVbWAOZz36k1hrQD7ijGRzLwaME8Am/7a4jZI=
github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=

View File

@@ -8,8 +8,8 @@ import (
) )
var ( var (
DefaultBufferSize int = 2048 // the default size of the buffer in bytes. DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
DefaultBlockSize int = 128 // the default size per R/W block in bytes. DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
ErrOutOfRange = fmt.Errorf("Indexes out of range") ErrOutOfRange = fmt.Errorf("Indexes out of range")
ErrInsufficientBytes = fmt.Errorf("Insufficient bytes to return") ErrInsufficientBytes = fmt.Errorf("Insufficient bytes to return")
@@ -43,6 +43,7 @@ func NewBuffer(size, block int) Buffer {
if block == 0 { if block == 0 {
block = DefaultBlockSize block = DefaultBlockSize
} }
if size < 2*block { if size < 2*block {
size = 2 * block size = 2 * block
} }
@@ -52,12 +53,32 @@ func NewBuffer(size, block int) Buffer {
mask: size - 1, mask: size - 1,
block: block, block: block,
buf: make([]byte, size), buf: make([]byte, size),
tmp: make([]byte, size),
rcond: sync.NewCond(new(sync.Mutex)), rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)), wcond: sync.NewCond(new(sync.Mutex)),
} }
} }
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := Buffer{
size: l,
mask: l - 1,
block: block,
buf: buf,
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
return b
}
// Get will return the tail and head positions of the buffer. // Get will return the tail and head positions of the buffer.
// This method is for use with testing. // This method is for use with testing.
func (b *Buffer) GetPos() (int64, int64) { func (b *Buffer) GetPos() (int64, int64) {

View File

@@ -36,6 +36,20 @@ func TestNewBufferUndersize(t *testing.T) {
require.Equal(t, DefaultBlockSize, buf.block) require.Equal(t, DefaultBlockSize, buf.block)
} }
func TestNewBufferFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestNewBufferFromSlice0Size(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(0, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestGetPos(t *testing.T) { func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4) buf := NewBuffer(16, 4)
tail, head := buf.GetPos() tail, head := buf.GetPos()

32
internal/circ/pool.go Normal file
View File

@@ -0,0 +1,32 @@
package circ
import (
"sync"
)
// BytesPool is a pool of []byte
type BytesPool struct {
pool sync.Pool
}
// NewBytesPool returns a sync.pool of []byte
func NewBytesPool(n int) BytesPool {
return BytesPool{
pool: sync.Pool{
New: func() interface{} {
return make([]byte, n)
},
},
}
}
// Get returns a pooled bytes.Buffer.
func (b BytesPool) Get() []byte {
return b.pool.Get().([]byte)
}
// Put puts the byte slice back into the pool.
func (b BytesPool) Put(x []byte) {
x = x[:0]
b.pool.Put(x)
}

View File

@@ -0,0 +1,46 @@
package circ
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBytesPool(t *testing.T) {
bpool := NewBytesPool(256)
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBytesPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBytesPool(256)
}
}
func TestNewBytesPoolGet(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, make([]byte, 256), buf)
}
func BenchmarkBytesPoolGet(b *testing.B) {
bpool := NewBytesPool(256)
for n := 0; n < b.N; n++ {
bpool.Get()
}
}
func TestNewBytesPoolPut(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
bpool.Put(buf)
}
func BenchmarkBytesPoolPut(b *testing.B) {
bpool := NewBytesPool(256)
buf := bpool.Get()
for n := 0; n < b.N; n++ {
bpool.Put(buf)
}
}

View File

@@ -3,8 +3,6 @@ package circ
import ( import (
"io" "io"
"sync/atomic" "sync/atomic"
dbg "github.com/mochi-co/debug"
) )
// Reader is a circular buffer for reading data from an io.Reader. // Reader is a circular buffer for reading data from an io.Reader.
@@ -12,7 +10,7 @@ type Reader struct {
Buffer Buffer
} }
// NewReader returns a pointer to a new Circular Reader. // NewReader returns a new Circular Reader.
func NewReader(size, block int) *Reader { func NewReader(size, block int) *Reader {
b := NewBuffer(size, block) b := NewBuffer(size, block)
b.ID = "\treader" b.ID = "\treader"
@@ -21,6 +19,16 @@ func NewReader(size, block int) *Reader {
} }
} }
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
// byte slice.
func NewReaderFromSlice(block int, p []byte) *Reader {
b := NewBufferFromSlice(block, p)
b.ID = "\treader"
return &Reader{
b,
}
}
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when // ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so. // there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) { func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
@@ -48,8 +56,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
end = b.size end = b.size
} }
dbg.Println(dbg.Yellow, b.ID, "b.ReadFrom allocating", start, ":", end)
// Read into the buffer between the start and end indexes only. // Read into the buffer between the start and end indexes only.
n, err := r.Read(b.buf[start:end]) n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read. total += int64(n) // incr total bytes read.
@@ -57,8 +63,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
return total, nil return total, nil
} }
dbg.Println(dbg.HiYellow, b.ID, "b.ReadFrom received", n, b.buf[start:start+n])
// Move the head forward however many bytes were read. // Move the head forward however many bytes were read.
atomic.AddInt64(&b.head, int64(n)) atomic.AddInt64(&b.head, int64(n))
@@ -71,8 +75,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
// Read reads n bytes from the buffer, and will block until at n bytes // Read reads n bytes from the buffer, and will block until at n bytes
// exist in the buffer to read. // exist in the buffer to read.
func (b *Buffer) Read(n int) (p []byte, err error) { func (b *Buffer) Read(n int) (p []byte, err error) {
dbg.Println(dbg.Cyan, b.ID, "b.Read waiting for", n, "bytes")
err = b.awaitFilled(n) err = b.awaitFilled(n)
if err != nil { if err != nil {
return return
@@ -90,7 +92,5 @@ func (b *Buffer) Read(n int) (p []byte, err error) {
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read. b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
} }
dbg.Println(dbg.HiCyan, b.ID, "b.Read read", tail, next, b.tmp)
return b.tmp, nil return b.tmp, nil
} }

View File

@@ -20,6 +20,13 @@ func TestNewReader(t *testing.T) {
require.Equal(t, block, buf.block) require.Equal(t, block, buf.block)
} }
func TestNewReaderFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestReadFrom(t *testing.T) { func TestReadFrom(t *testing.T) {
buf := NewReader(16, 4) buf := NewReader(16, 4)

View File

@@ -19,6 +19,16 @@ func NewWriter(size, block int) *Writer {
} }
} }
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
// byte slice.
func NewWriterFromSlice(block int, p []byte) *Writer {
b := NewBufferFromSlice(block, p)
b.ID = "writer"
return &Writer{
b,
}
}
// WriteTo writes the contents of the buffer to an io.Writer. // WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int, err error) { func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
atomic.StoreInt64(&b.State, 2) atomic.StoreInt64(&b.State, 2)

View File

@@ -22,6 +22,13 @@ func TestNewWriter(t *testing.T) {
require.Equal(t, block, buf.block) require.Equal(t, block, buf.block)
} }
func TestNewWriterFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestWriteTo(t *testing.T) { func TestWriteTo(t *testing.T) {
tests := []struct { tests := []struct {
tail int64 tail int64

View File

@@ -12,7 +12,6 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
dbg "github.com/mochi-co/debug"
"github.com/mochi-co/mqtt/internal/auth" "github.com/mochi-co/mqtt/internal/auth"
"github.com/mochi-co/mqtt/internal/circ" "github.com/mochi-co/mqtt/internal/circ"
"github.com/mochi-co/mqtt/internal/packets" "github.com/mochi-co/mqtt/internal/packets"
@@ -210,27 +209,26 @@ func (cl *Client) Start() {
go func() { go func() {
cl.state.started.Done() cl.state.started.Done()
_, err := cl.w.WriteTo(cl.conn) //_, err :=
dbg.Println(dbg.HiRed, cl.ID, "WriteTo stopped", err) cl.w.WriteTo(cl.conn)
cl.state.endedW.Done() cl.state.endedW.Done()
//cl.close() //cl.close()
}() }()
cl.state.endedW.Add(1) cl.state.endedW.Add(1)
go func() { go func() {
cl.state.started.Done() cl.state.started.Done()
_, err := cl.r.ReadFrom(cl.conn) //_, err :=
dbg.Println(dbg.HiRed, cl.ID, "ReadFrom stopped", err) cl.r.ReadFrom(cl.conn)
cl.state.endedR.Done() cl.state.endedR.Done()
//cl.close() //cl.close()
}() }()
cl.state.endedR.Add(1) cl.state.endedR.Add(1)
cl.state.started.Wait() cl.state.started.Wait()
} }
// Stop instructs the client to shut down all processing goroutines and disconnect. // Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() { func (cl *Client) Stop() {
dbg.Println(dbg.HiRed+"CLIENT stop called...", dbg.Underline+cl.ID)
cl.r.Stop() cl.r.Stop()
cl.w.Stop() cl.w.Stop()
cl.state.endedW.Wait() cl.state.endedW.Wait()
@@ -241,7 +239,6 @@ func (cl *Client) Stop() {
} }
cl.state.endedR.Wait() cl.state.endedR.Wait()
dbg.Println(dbg.HiRed+"CLIENT stopped", dbg.Underline+cl.ID)
} }
// readFixedHeader reads in the values of the next packet's fixed header. // readFixedHeader reads in the values of the next packet's fixed header.

135
mqtt.go
View File

@@ -23,28 +23,11 @@ var (
ErrListenerIDExists = errors.New("Listener id already exists") ErrListenerIDExists = errors.New("Listener id already exists")
ErrReadConnectInvalid = errors.New("Connect packet was not valid") ErrReadConnectInvalid = errors.New("Connect packet was not valid")
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized") ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
// ErrACLNotAuthorized = errors.New("ACL not authorized")
) )
/*
var (
ErrListenerIDExists = errors.New("Listener id already exists")
ErrReadConnectFixedHeader = errors.New("Error reading fixed header on CONNECT packet")
ErrReadConnectPacket = errors.New("Error reading CONNECT packet")
ErrReadConnectInvalid = errors.New("CONNECT packet was not valid")
ErrReadFixedHeader = errors.New("Error reading fixed header")
ErrReadPacketPayload = errors.New("Error reading packet payload")
ErrReadPacketValidation = errors.New("Error validating packet")
ErrConnectionClosed = errors.New("Connection not open")
ErrNoData = errors.New("No data")
ErrACLNotAuthorized = errors.New("ACL not authorized")
)
*/
// Server is an MQTT broker server. // Server is an MQTT broker server.
type Server struct { type Server struct {
bytepool circ.BytesPool
Listeners listeners.Listeners // listeners listen for new connections. Listeners listeners.Listeners // listeners listen for new connections.
Clients clients.Clients // clients known to the broker. Clients clients.Clients // clients known to the broker.
Topics *topics.Index // an index of topic subscriptions and retained messages. Topics *topics.Index // an index of topic subscriptions and retained messages.
@@ -52,8 +35,8 @@ type Server struct {
// New returns a new instance of an MQTT broker. // New returns a new instance of an MQTT broker.
func New() *Server { func New() *Server {
fmt.Println()
return &Server{ return &Server{
bytepool: circ.NewBytesPool(circ.DefaultBufferSize),
Listeners: listeners.New(), Listeners: listeners.New(),
Clients: clients.New(), Clients: clients.New(),
Topics: topics.New(), Topics: topics.New(),
@@ -82,53 +65,14 @@ func (s *Server) Serve() error {
return nil return nil
} }
// Close attempts to gracefully shutdown the server, all listeners, and clients.
func (s *Server) Close() error {
s.Listeners.CloseAll(s.closeListenerClients)
return nil
}
// closeListenerClients closes all clients on the specified listener.
func (s *Server) closeListenerClients(listener string) {
clients := s.Clients.GetByListener(listener)
for _, client := range clients {
s.closeClient(client, false) // omit errors
}
}
// closeClient closes a client connection and publishes any LWT messages.
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
//debug.Println(cl.ID, "SERVER STOPS ISSUED >> ")
// If an LWT message is set, publish it to the topic subscribers.
/* // this currently loops forever on broken connection
if sendLWT && cl.lwt.topic != "" {
err := s.processPublish(cl, &packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Retain: cl.lwt.retain,
Qos: cl.lwt.qos,
},
TopicName: cl.lwt.topic,
Payload: cl.lwt.message,
})
if err != nil {
return err
}
}
*/
// Stop listening for new packets.
cl.Stop()
return nil
}
// EstablishConnection establishes a new client connection with the broker. // EstablishConnection establishes a new client connection with the broker.
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error { func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
client := clients.NewClient(c, circ.NewReader(0, 0), circ.NewWriter(0, 0)) //client := clients.NewClient(c, circ.NewReader(0, 0), circ.NewWriter(0, 0))
client := clients.NewClient(c,
circ.NewReaderFromSlice(0, s.bytepool.Get()),
circ.NewWriterFromSlice(0, s.bytepool.Get()),
)
client.Start() client.Start()
fh := new(packets.FixedHeader) fh := new(packets.FixedHeader)
@@ -151,6 +95,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
if !ac.Authenticate(pk.Username, pk.Password) { if !ac.Authenticate(pk.Username, pk.Password) {
retcode = packets.CodeConnectBadAuthValues retcode = packets.CodeConnectBadAuthValues
} }
var sessionPresent bool var sessionPresent bool
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok { if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock() existing.Lock()
@@ -168,10 +113,8 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
existing.Unlock() existing.Unlock()
} }
// Add the new client to the clients manager.
s.Clients.Add(client) s.Clients.Add(client)
// Send a CONNACK back to the client with retcode.
err = s.writeClient(client, packets.Packet{ err = s.writeClient(client, packets.Packet{
FixedHeader: packets.FixedHeader{ FixedHeader: packets.FixedHeader{
Type: packets.Connack, Type: packets.Connack,
@@ -184,21 +127,17 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
return err return err
} }
// Resend any unacknowledged QOS messages still pending for the client. err = s.ResendInflight(client)
/*err = s.resendInflight(client)
if err != nil { if err != nil {
return err return err
} }
*/
// Block and listen for more packets, and end if an error or nil packet occurs.
var sendLWT bool var sendLWT bool
err = client.Read(s.processPacket) err = client.Read(s.processPacket)
if err != nil { if err != nil {
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3] sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
} }
// Publish last will and testament then close.
s.closeClient(client, sendLWT) s.closeClient(client, sendLWT)
return err return err
@@ -217,13 +156,10 @@ func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
return nil return nil
} }
// resendInflight republishes any inflight messages to the client. // ResendInflight republishes any inflight messages to the client.
/*func (s *Server) resendInflight(cl *clients.Client) error { func (s *Server) ResendInflight(cl *clients.Client) error {
cl.RLock() for _, pk := range cl.InFlight.GetAll() {
msgs := cl.inFlight.internal err := s.writeClient(cl, pk.Packet)
cl.RUnlock()
for _, msg := range msgs {
err := s.writeClient(cl, msg.packet)
if err != nil { if err != nil {
return err return err
} }
@@ -231,7 +167,6 @@ func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
return nil return nil
} }
*/
// processPacket processes an inbound packet for a client. Since the method is // processPacket processes an inbound packet for a client. Since the method is
// typically called as a goroutine, errors are mostly for test checking purposes. // typically called as a goroutine, errors are mostly for test checking purposes.
@@ -453,3 +388,45 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) (clos
return return
} }
// Close attempts to gracefully shutdown the server, all listeners, and clients.
func (s *Server) Close() error {
s.Listeners.CloseAll(s.closeListenerClients)
return nil
}
// closeListenerClients closes all clients on the specified listener.
func (s *Server) closeListenerClients(listener string) {
clients := s.Clients.GetByListener(listener)
for _, client := range clients {
s.closeClient(client, false) // omit errors
}
}
// closeClient closes a client connection and publishes any LWT messages.
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
// If an LWT message is set, publish it to the topic subscribers.
/*
if sendLWT && cl.lwt.topic != "" {
err := s.processPublish(cl, &packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Retain: cl.lwt.retain,
Qos: cl.lwt.qos,
},
TopicName: cl.lwt.topic,
Payload: cl.lwt.message,
})
if err != nil {
return err
}
}
*/
// Stop listening for new packets.
cl.Stop()
return nil
}

View File

@@ -819,7 +819,6 @@ func TestServerProcessSubscribeWriteError(t *testing.T) {
func TestServerProcessUnsubscribe(t *testing.T) { func TestServerProcessUnsubscribe(t *testing.T) {
s, cl, r, w := setupClient() s, cl, r, w := setupClient()
s.Clients.Add(cl) s.Clients.Add(cl)
s.Topics.Subscribe("a/b/c", cl.ID, 0) s.Topics.Subscribe("a/b/c", cl.ID, 0)
s.Topics.Subscribe("d/e/f", cl.ID, 1) s.Topics.Subscribe("d/e/f", cl.ID, 1)
@@ -880,133 +879,22 @@ func TestServerProcessUnsubscribeWriteError(t *testing.T) {
require.Equal(t, false, close) require.Equal(t, false, close)
} }
/* func TestServerClose(t *testing.T) {
s, cl, _, _ := setupClient()
cl.Listener = "t1"
s.Clients.Add(cl)
err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil)
func TestServerProcessSubscribeWriteRetainedError(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.p.W = &quietWriter{errAfter: 1}
s.topics.RetainMessage(&packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Retain: true,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
})
require.Equal(t, 1, len(s.topics.Messages("a/b/c")))
err := s.processPacket(cl, &packets.SubscribePacket{
FixedHeader: packets.FixedHeader{
Type: packets.Subscribe,
},
PacketID: 10,
Topics: []string{"a/b/c", "d/e/f"},
Qoss: []byte{0, 1},
})
require.Error(t, err)
}
func TestServerProcessUnsubscribe(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.p.W = new(quietWriter)
s.clients.add(cl)
s.topics.Subscribe("a/b/c", cl.id, 0)
s.topics.Subscribe("d/e/f", cl.id, 1)
cl.noteSubscription("a/b/c", 0)
cl.noteSubscription("d/e/f", 1)
err := s.processPacket(cl, &packets.UnsubscribePacket{
FixedHeader: packets.FixedHeader{
Type: packets.Unsubscribe,
},
PacketID: 12,
Topics: []string{"a/b/c", "d/e/f"},
})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte{ s.Serve()
byte(packets.Unsuback << 4), 2, // Fixed header time.Sleep(time.Millisecond)
0, 12, // Packet ID - LSB+MSB require.Equal(t, 1, s.Listeners.Len())
}, cl.p.W.(*quietWriter).f[0])
require.Empty(t, s.topics.Subscribers("a/b/c")) listener, ok := s.Listeners.Get("t1")
require.Empty(t, s.topics.Subscribers("d/e/f")) require.Equal(t, true, ok)
require.NotContains(t, cl.subscriptions, "a/b/c") require.Equal(t, true, listener.(*listeners.MockListener).IsServing)
require.NotContains(t, cl.subscriptions, "d/e/f")
s.Close()
time.Sleep(time.Millisecond)
require.Equal(t, false, listener.(*listeners.MockListener).IsServing)
} }
func BenchmarkServerProcessUnsubscribe(b *testing.B) {
s, _, _, cl := setupClient("zen")
cl.p.W = new(quietWriter)
pk := &packets.UnsubscribePacket{
FixedHeader: packets.FixedHeader{
Type: packets.Unsubscribe,
},
PacketID: 12,
Topics: []string{"a/b/c"},
}
for n := 0; n < b.N; n++ {
err := s.processUnsubscribe(cl, pk)
if err != nil {
panic(err)
}
}
}
func TestServerProcessUnsubscribeWriteError(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.p.W = &quietWriter{errAfter: -1}
err := s.processPacket(cl, &packets.UnsubscribePacket{
FixedHeader: packets.FixedHeader{
Type: packets.Unsubscribe,
},
})
require.Error(t, err)
}
*/
/*
func TestResendInflight(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.inFlight.set(1, &inFlightMessage{
packet: &packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: 1,
Retain: true,
Dup: true,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
PacketID: 1,
},
sent: time.Now().Unix(),
})
err := s.resendInflight(cl)
require.NoError(t, err)
require.Equal(t, []byte{
byte(packets.Publish<<4 | 11), 14, // Fixed header QoS : 1
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
0, 1, // packet id from qos=1
'h', 'e', 'l', 'l', 'o', // Payload)
}, cl.p.W.Get()[:16])
}
func TestResendInflightWriteError(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.inFlight.set(1, &inFlightMessage{
packet: &packets.PublishPacket{},
})
cl.p.W.Close()
err := s.resendInflight(cl)
require.Error(t, err)
}
*/