mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-26 20:21:12 +08:00
Packet encoding optimization (#343)
* Dynamically allocate buffer for writes if needed * Remove unused net.Buffer * Return bytes written to buffer instead of conn * Dynamic write buffer * Reduce double write of pk.Payload * Use memory pool for packet encode * Pool doesn't guarantee value between Put and Get * Add benchmark for bufpool * Fix issue #346 * Change default pool not to have size cap --------- Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
This commit is contained in:
81
mempool/bufpool.go
Normal file
81
mempool/bufpool.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bufPool = NewBuffer(0)
|
||||
|
||||
// GetBuffer takes a Buffer from the default buffer pool
|
||||
func GetBuffer() *bytes.Buffer { return bufPool.Get() }
|
||||
|
||||
// PutBuffer returns Buffer to the default buffer pool
|
||||
func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) }
|
||||
|
||||
type BufferPool interface {
|
||||
Get() *bytes.Buffer
|
||||
Put(x *bytes.Buffer)
|
||||
}
|
||||
|
||||
// NewBuffer returns a buffer pool. The max specify the max capacity of the Buffer the pool will
|
||||
// return. If the Buffer becoomes large than max, it will no longer be returned to the pool. If
|
||||
// max <= 0, no limit will be enforced.
|
||||
func NewBuffer(max int) BufferPool {
|
||||
if max > 0 {
|
||||
return newBufferWithCap(max)
|
||||
}
|
||||
|
||||
return newBuffer()
|
||||
}
|
||||
|
||||
// Buffer is a Buffer pool.
|
||||
type Buffer struct {
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
func newBuffer() *Buffer {
|
||||
return &Buffer{
|
||||
pool: &sync.Pool{
|
||||
New: func() any { return new(bytes.Buffer) },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get a Buffer from the pool.
|
||||
func (b *Buffer) Get() *bytes.Buffer {
|
||||
return b.pool.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// Put the Buffer back into pool. It resets the Buffer for reuse.
|
||||
func (b *Buffer) Put(x *bytes.Buffer) {
|
||||
x.Reset()
|
||||
b.pool.Put(x)
|
||||
}
|
||||
|
||||
// BufferWithCap is a Buffer pool that
|
||||
type BufferWithCap struct {
|
||||
bp *Buffer
|
||||
max int
|
||||
}
|
||||
|
||||
func newBufferWithCap(max int) *BufferWithCap {
|
||||
return &BufferWithCap{
|
||||
bp: newBuffer(),
|
||||
max: max,
|
||||
}
|
||||
}
|
||||
|
||||
// Get a Buffer from the pool.
|
||||
func (b *BufferWithCap) Get() *bytes.Buffer {
|
||||
return b.bp.Get()
|
||||
}
|
||||
|
||||
// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer
|
||||
// for reuse.
|
||||
func (b *BufferWithCap) Put(x *bytes.Buffer) {
|
||||
if x.Cap() > b.max {
|
||||
return
|
||||
}
|
||||
b.bp.Put(x)
|
||||
}
|
96
mempool/bufpool_test.go
Normal file
96
mempool/bufpool_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package mempool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBuffer(t *testing.T) {
|
||||
defer debug.SetGCPercent(debug.SetGCPercent(-1))
|
||||
bp := NewBuffer(1000)
|
||||
require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String())
|
||||
|
||||
bp = NewBuffer(0)
|
||||
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
|
||||
|
||||
bp = NewBuffer(-1)
|
||||
require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String())
|
||||
}
|
||||
|
||||
func TestBuffer(t *testing.T) {
|
||||
defer debug.SetGCPercent(debug.SetGCPercent(-1))
|
||||
Size := 101
|
||||
|
||||
bp := NewBuffer(0)
|
||||
buf := bp.Get()
|
||||
|
||||
for i := 0; i < Size; i++ {
|
||||
buf.WriteByte('a')
|
||||
}
|
||||
|
||||
bp.Put(buf)
|
||||
buf = bp.Get()
|
||||
require.Equal(t, 0, buf.Len())
|
||||
}
|
||||
|
||||
func TestBufferWithCap(t *testing.T) {
|
||||
defer debug.SetGCPercent(debug.SetGCPercent(-1))
|
||||
Size := 101
|
||||
bp := NewBuffer(100)
|
||||
buf := bp.Get()
|
||||
|
||||
for i := 0; i < Size; i++ {
|
||||
buf.WriteByte('a')
|
||||
}
|
||||
|
||||
bp.Put(buf)
|
||||
buf = bp.Get()
|
||||
require.Equal(t, 0, buf.Len())
|
||||
require.Equal(t, 0, buf.Cap())
|
||||
}
|
||||
|
||||
func BenchmarkBufferPool(b *testing.B) {
|
||||
bp := NewBuffer(0)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := bp.Get()
|
||||
b.WriteString("this is a test")
|
||||
bp.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBufferPoolWithCapLarger(b *testing.B) {
|
||||
bp := NewBuffer(64 * 1024)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := bp.Get()
|
||||
b.WriteString("this is a test")
|
||||
bp.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBufferPoolWithCapLesser(b *testing.B) {
|
||||
bp := NewBuffer(10)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := bp.Get()
|
||||
b.WriteString("this is a test")
|
||||
bp.Put(b)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBufferWithoutPool(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b := new(bytes.Buffer)
|
||||
b.WriteString("this is a test")
|
||||
_ = b
|
||||
}
|
||||
}
|
@@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-mqtt/server/v2/mempool"
|
||||
)
|
||||
|
||||
// All valid packet types and their packet identifiers.
|
||||
@@ -298,7 +300,8 @@ func (s *Subscription) decode(b byte) {
|
||||
|
||||
// ConnectEncode encodes a connect packet.
|
||||
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeBytes(pk.Connect.ProtocolName))
|
||||
nb.WriteByte(pk.ProtocolVersion)
|
||||
|
||||
@@ -315,7 +318,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
nb.Write(encodeUint16(pk.Connect.Keepalive))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -324,7 +328,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.Connect.WillFlag {
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -493,12 +498,14 @@ func (pk *Packet) ConnectValidate() Code {
|
||||
|
||||
// ConnackEncode encodes a Connack packet.
|
||||
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.WriteByte(encodeBool(pk.SessionPresent))
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -536,12 +543,14 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
|
||||
// DisconnectEncode encodes a Disconnect packet.
|
||||
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -598,7 +607,8 @@ func (pk *Packet) PingrespDecode(buf []byte) error {
|
||||
|
||||
// PublishEncode encodes a Publish packet.
|
||||
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
|
||||
nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1]
|
||||
|
||||
@@ -610,16 +620,16 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
}
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
nb.Write(pk.Payload)
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload)
|
||||
pk.FixedHeader.Encode(buf)
|
||||
_, _ = nb.WriteTo(buf)
|
||||
buf.Write(pk.Payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -690,11 +700,13 @@ func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code {
|
||||
|
||||
// encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet.
|
||||
func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
@@ -831,11 +843,13 @@ func (pk *Packet) ReasonCodeValid() bool {
|
||||
|
||||
// SubackEncode encodes a Suback packet.
|
||||
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -878,10 +892,12 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
return ErrProtocolViolationNoPacketID
|
||||
}
|
||||
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
xb := bytes.NewBuffer([]byte{}) // capture and write filters after length checks
|
||||
xb := mempool.GetBuffer() // capture and write filters after length checks
|
||||
defer mempool.PutBuffer(xb)
|
||||
for _, opts := range pk.Filters {
|
||||
xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1]
|
||||
if pk.ProtocolVersion == 5 {
|
||||
@@ -892,7 +908,8 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
}
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -983,11 +1000,13 @@ func (pk *Packet) SubscribeValidate() Code {
|
||||
|
||||
// UnsubackEncode encodes an Unsuback packet.
|
||||
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -1031,16 +1050,19 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
return ErrProtocolViolationNoPacketID
|
||||
}
|
||||
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.Write(encodeUint16(pk.PacketID))
|
||||
|
||||
xb := bytes.NewBuffer([]byte{}) // capture filters and write after length checks
|
||||
xb := mempool.GetBuffer() // capture filters and write after length checks
|
||||
defer mempool.PutBuffer(xb)
|
||||
for _, sub := range pk.Filters {
|
||||
xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1]
|
||||
}
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
@@ -1100,10 +1122,12 @@ func (pk *Packet) UnsubscribeValidate() Code {
|
||||
|
||||
// AuthEncode encodes an Auth packet.
|
||||
func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
|
||||
nb := bytes.NewBuffer([]byte{})
|
||||
nb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(nb)
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pb := mempool.GetBuffer()
|
||||
defer mempool.PutBuffer(pb)
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
|
||||
|
Reference in New Issue
Block a user