Packet encoding optimization (#343)

* Dynamically allocate buffer for writes if needed

* Remove unused net.Buffer

* Return bytes written to buffer instead of conn

* Dynamic write buffer

* Reduce double write of pk.Payload

* Use memory pool for packet encode

* Pool doesn't guarantee value between Put and Get

* Add benchmark for bufpool

* Fix issue #346

* Change default pool not to have size cap

---------

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
This commit is contained in:
thedevop
2023-12-21 15:23:35 -08:00
committed by GitHub
parent 4c682384c5
commit c6c7c296f6
3 changed files with 227 additions and 26 deletions

81
mempool/bufpool.go Normal file
View File

@@ -0,0 +1,81 @@
package mempool
import (
"bytes"
"sync"
)
var bufPool = NewBuffer(0)
// GetBuffer takes a Buffer from the default buffer pool
func GetBuffer() *bytes.Buffer { return bufPool.Get() }
// PutBuffer returns Buffer to the default buffer pool
func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) }
type BufferPool interface {
Get() *bytes.Buffer
Put(x *bytes.Buffer)
}
// NewBuffer returns a buffer pool. The max specify the max capacity of the Buffer the pool will
// return. If the Buffer becoomes large than max, it will no longer be returned to the pool. If
// max <= 0, no limit will be enforced.
func NewBuffer(max int) BufferPool {
if max > 0 {
return newBufferWithCap(max)
}
return newBuffer()
}
// Buffer is a Buffer pool.
type Buffer struct {
pool *sync.Pool
}
func newBuffer() *Buffer {
return &Buffer{
pool: &sync.Pool{
New: func() any { return new(bytes.Buffer) },
},
}
}
// Get a Buffer from the pool.
func (b *Buffer) Get() *bytes.Buffer {
return b.pool.Get().(*bytes.Buffer)
}
// Put the Buffer back into pool. It resets the Buffer for reuse.
func (b *Buffer) Put(x *bytes.Buffer) {
x.Reset()
b.pool.Put(x)
}
// BufferWithCap is a Buffer pool that
type BufferWithCap struct {
bp *Buffer
max int
}
func newBufferWithCap(max int) *BufferWithCap {
return &BufferWithCap{
bp: newBuffer(),
max: max,
}
}
// Get a Buffer from the pool.
func (b *BufferWithCap) Get() *bytes.Buffer {
return b.bp.Get()
}
// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer
// for reuse.
func (b *BufferWithCap) Put(x *bytes.Buffer) {
if x.Cap() > b.max {
return
}
b.bp.Put(x)
}

96
mempool/bufpool_test.go Normal file
View File

@@ -0,0 +1,96 @@
package mempool
import (
"bytes"
"reflect"
"runtime/debug"
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBuffer(t *testing.T) {
defer debug.SetGCPercent(debug.SetGCPercent(-1))
bp := NewBuffer(1000)
require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String())
bp = NewBuffer(0)
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
bp = NewBuffer(-1)
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
}
func TestBuffer(t *testing.T) {
defer debug.SetGCPercent(debug.SetGCPercent(-1))
Size := 101
bp := NewBuffer(0)
buf := bp.Get()
for i := 0; i < Size; i++ {
buf.WriteByte('a')
}
bp.Put(buf)
buf = bp.Get()
require.Equal(t, 0, buf.Len())
}
func TestBufferWithCap(t *testing.T) {
defer debug.SetGCPercent(debug.SetGCPercent(-1))
Size := 101
bp := NewBuffer(100)
buf := bp.Get()
for i := 0; i < Size; i++ {
buf.WriteByte('a')
}
bp.Put(buf)
buf = bp.Get()
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.Cap())
}
func BenchmarkBufferPool(b *testing.B) {
bp := NewBuffer(0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
b := bp.Get()
b.WriteString("this is a test")
bp.Put(b)
}
}
func BenchmarkBufferPoolWithCapLarger(b *testing.B) {
bp := NewBuffer(64 * 1024)
b.ResetTimer()
for i := 0; i < b.N; i++ {
b := bp.Get()
b.WriteString("this is a test")
bp.Put(b)
}
}
func BenchmarkBufferPoolWithCapLesser(b *testing.B) {
bp := NewBuffer(10)
b.ResetTimer()
for i := 0; i < b.N; i++ {
b := bp.Get()
b.WriteString("this is a test")
bp.Put(b)
}
}
func BenchmarkBufferWithoutPool(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
b := new(bytes.Buffer)
b.WriteString("this is a test")
_ = b
}
}

View File

@@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"sync"
"github.com/mochi-mqtt/server/v2/mempool"
)
// All valid packet types and their packet identifiers.
@@ -298,7 +300,8 @@ func (s *Subscription) decode(b byte) {
// ConnectEncode encodes a connect packet.
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeBytes(pk.Connect.ProtocolName))
nb.WriteByte(pk.ProtocolVersion)
@@ -315,7 +318,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
nb.Write(encodeUint16(pk.Connect.Keepalive))
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -324,7 +328,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -493,12 +498,14 @@ func (pk *Packet) ConnectValidate() Code {
// ConnackEncode encodes a Connack packet.
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.WriteByte(encodeBool(pk.SessionPresent))
nb.WriteByte(pk.ReasonCode)
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes())
}
@@ -536,12 +543,14 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
// DisconnectEncode encodes a Disconnect packet.
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
if pk.ProtocolVersion == 5 {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -598,7 +607,8 @@ func (pk *Packet) PingrespDecode(buf []byte) error {
// PublishEncode encodes a Publish packet.
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1]
@@ -610,16 +620,16 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
}
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes())
}
nb.Write(pk.Payload)
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload)
pk.FixedHeader.Encode(buf)
_, _ = nb.WriteTo(buf)
buf.Write(pk.Payload)
return nil
}
@@ -690,11 +700,13 @@ func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code {
// encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet.
func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode)
@@ -831,11 +843,13 @@ func (pk *Packet) ReasonCodeValid() bool {
// SubackEncode encodes a Suback packet.
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes())
}
@@ -878,10 +892,12 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
return ErrProtocolViolationNoPacketID
}
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))
xb := bytes.NewBuffer([]byte{}) // capture and write filters after length checks
xb := mempool.GetBuffer() // capture and write filters after length checks
defer mempool.PutBuffer(xb)
for _, opts := range pk.Filters {
xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1]
if pk.ProtocolVersion == 5 {
@@ -892,7 +908,8 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
}
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -983,11 +1000,13 @@ func (pk *Packet) SubscribeValidate() Code {
// UnsubackEncode encodes an Unsuback packet.
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -1031,16 +1050,19 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
return ErrProtocolViolationNoPacketID
}
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.Write(encodeUint16(pk.PacketID))
xb := bytes.NewBuffer([]byte{}) // capture filters and write after length checks
xb := mempool.GetBuffer() // capture filters and write after length checks
defer mempool.PutBuffer(xb)
for _, sub := range pk.Filters {
xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1]
}
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -1100,10 +1122,12 @@ func (pk *Packet) UnsubscribeValidate() Code {
// AuthEncode encodes an Auth packet.
func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb := bytes.NewBuffer([]byte{})
nb := mempool.GetBuffer()
defer mempool.PutBuffer(nb)
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pb := mempool.GetBuffer()
defer mempool.PutBuffer(pb)
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())