diff --git a/mqtt.go b/mqtt.go index 58f29fa..042d063 100644 --- a/mqtt.go +++ b/mqtt.go @@ -1,6 +1,7 @@ package mqtt import ( + "bufio" "errors" "log" "net" @@ -38,6 +39,9 @@ var ( // clientKeepalive is the default keepalive time in seconds. clientKeepalive uint16 = 60 + + // rwBufSize is the size of client read/write buffers. + rwBufSize = 512 ) /* @@ -75,12 +79,6 @@ type Server struct { // buffers is a pool of bytes.buffer. buffers pools.BytesBuffersPool - - // readers is a pool of bufio.reader. - readers pools.BufioReadersPool - - // writers is a pool of bufio.writer. - writers pools.BufioWritersPool } // New returns a pointer to a new instance of the MQTT broker. @@ -90,8 +88,6 @@ func New() *Server { clients: newClients(), topics: trie.New(), buffers: pools.NewBytesBuffersPool(), - readers: pools.NewBufioReadersPool(512), - writers: pools.NewBufioWritersPool(512), } } @@ -128,11 +124,12 @@ func (s *Server) EstablishConnection(c net.Conn, ac auth.Controller) error { log.Println("connecting") // Create a new packets parser which will parse all packets for this client, - // using buffered writers and readers from the pool. - r, w := s.readers.Get(c), s.writers.Get(c) - defer s.readers.Put(r) - defer s.writers.Put(w) - p := packets.NewParser(c, r, w) + // using buffered writers and readers. + p := packets.NewParser( + c, + bufio.NewReaderSize(c, rwBufSize), + bufio.NewWriterSize(c, rwBufSize), + ) // Pull the header from the first packet and check for a CONNECT message. fh := new(packets.FixedHeader) @@ -515,7 +512,8 @@ func newClient(p *packets.Parser, pk *packets.ConnectPacket) *client { // nextPacketID returns the next packet id for a client, looping back to 0 // if the maximum ID has been reached. func (cl *client) nextPacketID() uint32 { - if cl.packetID == uint32(65535) || cl.packetID == uint32(0) { + i := atomic.LoadUint32(&cl.packetID) + if i == uint32(65535) || i == uint32(0) { atomic.StoreUint32(&cl.packetID, 1) return 1 } @@ -526,6 +524,7 @@ func (cl *client) nextPacketID() uint32 { // close attempts to gracefully close a client connection. func (cl *client) close() { cl.done.Do(func() { + // Signal to stop lsitening for packets. close(cl.end) diff --git a/mqtt_test.go b/mqtt_test.go index 0a747f0..095388a 100644 --- a/mqtt_test.go +++ b/mqtt_test.go @@ -29,8 +29,6 @@ func TestNew(t *testing.T) { require.NotNil(t, s.listeners) require.NotNil(t, s.clients) require.NotNil(t, s.buffers) - require.NotNil(t, s.readers) - require.NotNil(t, s.writers) log.Println(s) } @@ -806,7 +804,16 @@ func TestNextPacketID(t *testing.T) { client.packetID = uint32(65534) require.Equal(t, uint32(65535), client.nextPacketID()) require.Equal(t, uint32(1), client.nextPacketID()) +} +func BenchmarkNextPacketID(b *testing.B) { + r, w := net.Pipe() + p := packets.NewParser(r, newBufioReader(r), newBufioWriter(w)) + client := newClient(p, new(packets.ConnectPacket)) + + for n := 0; n < b.N; n++ { + client.nextPacketID() + } } func TestClientClose(t *testing.T) {