From 18005a9cde8d7bd3b16d5dd95e659526b0b0b74b Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 5 Jul 2022 19:25:53 +0200 Subject: [PATCH] ringbuffer: force size to be a power of 2 otherwise buffer is used partially when writeIndex overflows. --- client.go | 7 +++++-- pkg/ringbuffer/ringbuffer.go | 21 ++++++++++++++------- pkg/ringbuffer/ringbuffer_test.go | 20 ++++++++++++++------ server.go | 3 +++ servermulticasthandler.go | 4 +++- serversession.go | 4 ++-- 6 files changed, 41 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index 0b727c1d..a74590d1 100644 --- a/client.go +++ b/client.go @@ -318,6 +318,9 @@ func (c *Client) Start(scheme string, host string) error { if c.WriteBufferCount == 0 { c.WriteBufferCount = 256 } + if (c.WriteBufferCount & (c.WriteBufferCount - 1)) != 0 { + return fmt.Errorf("WriteBufferCount must be a power of two") + } if c.UserAgent == "" { c.UserAgent = "gortsplib" } @@ -699,9 +702,9 @@ func (c *Client) playRecordStart() { // when reading, writeBuffer is only used to send RTCP receiver reports, // that are much smaller than RTP packets and are sent at a fixed interval. // decrease RAM consumption by allocating less buffers. - c.writeBuffer = ringbuffer.New(8) + c.writeBuffer, _ = ringbuffer.New(8) } else { - c.writeBuffer = ringbuffer.New(uint64(c.WriteBufferCount)) + c.writeBuffer, _ = ringbuffer.New(uint64(c.WriteBufferCount)) } c.writerRunning = true c.writerDone = make(chan struct{}) diff --git a/pkg/ringbuffer/ringbuffer.go b/pkg/ringbuffer/ringbuffer.go index e2c63edd..9289611a 100644 --- a/pkg/ringbuffer/ringbuffer.go +++ b/pkg/ringbuffer/ringbuffer.go @@ -2,13 +2,14 @@ package ringbuffer import ( + "fmt" "sync/atomic" "unsafe" ) // RingBuffer is a ring buffer. type RingBuffer struct { - bufferSize uint64 + size uint64 readIndex uint64 writeIndex uint64 closed int64 @@ -17,14 +18,20 @@ type RingBuffer struct { } // New allocates a RingBuffer. -func New(size uint64) *RingBuffer { +func New(size uint64) (*RingBuffer, error) { + // when writeIndex overflows, if size is not a power of + // two, only a portion of the buffer is used. + if (size & (size - 1)) != 0 { + return nil, fmt.Errorf("size must be a power of two") + } + return &RingBuffer{ - bufferSize: size, + size: size, readIndex: 1, writeIndex: 0, buffer: make([]unsafe.Pointer, size), event: newEvent(), - } + }, nil } // Close makes Pull() return false. @@ -35,7 +42,7 @@ func (r *RingBuffer) Close() { // Reset restores Pull() after a Close(). func (r *RingBuffer) Reset() { - for i := uint64(0); i < r.bufferSize; i++ { + for i := uint64(0); i < r.size; i++ { atomic.SwapPointer(&r.buffer[i], nil) } atomic.SwapUint64(&r.writeIndex, 0) @@ -46,7 +53,7 @@ func (r *RingBuffer) Reset() { // Push pushes some data at the end of the buffer. func (r *RingBuffer) Push(data interface{}) { writeIndex := atomic.AddUint64(&r.writeIndex, 1) - i := writeIndex % r.bufferSize + i := writeIndex % r.size atomic.SwapPointer(&r.buffer[i], unsafe.Pointer(&data)) r.event.signal() } @@ -54,7 +61,7 @@ func (r *RingBuffer) Push(data interface{}) { // Pull pulls some data from the beginning of the buffer. func (r *RingBuffer) Pull() (interface{}, bool) { for { - i := r.readIndex % r.bufferSize + i := r.readIndex % r.size res := (*interface{})(atomic.SwapPointer(&r.buffer[i], nil)) if res == nil { if atomic.SwapInt64(&r.closed, 0) == 1 { diff --git a/pkg/ringbuffer/ringbuffer_test.go b/pkg/ringbuffer/ringbuffer_test.go index 6f02dbc4..25a9635a 100644 --- a/pkg/ringbuffer/ringbuffer_test.go +++ b/pkg/ringbuffer/ringbuffer_test.go @@ -8,8 +8,14 @@ import ( "github.com/stretchr/testify/require" ) +func TestCreateError(t *testing.T) { + _, err := New(1000) + require.EqualError(t, err, "size must be a power of two") +} + func TestPushBeforePull(t *testing.T) { - r := New(1024) + r, err := New(1024) + require.NoError(t, err) defer r.Close() data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4) @@ -21,7 +27,8 @@ func TestPushBeforePull(t *testing.T) { } func TestPullBeforePush(t *testing.T) { - r := New(1024) + r, err := New(1024) + require.NoError(t, err) defer r.Close() data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4) @@ -41,7 +48,8 @@ func TestPullBeforePush(t *testing.T) { } func TestClose(t *testing.T) { - r := New(1024) + r, err := New(1024) + require.NoError(t, err) done := make(chan struct{}) go func() { @@ -68,7 +76,7 @@ func TestClose(t *testing.T) { } func BenchmarkPushPullContinuous(b *testing.B) { - r := New(1024 * 8) + r, _ := New(1024 * 8) defer r.Close() data := make([]byte, 1024) @@ -91,7 +99,7 @@ func BenchmarkPushPullContinuous(b *testing.B) { } func BenchmarkPushPullPaused5(b *testing.B) { - r := New(128) + r, _ := New(128) defer r.Close() data := make([]byte, 1024) @@ -115,7 +123,7 @@ func BenchmarkPushPullPaused5(b *testing.B) { } func BenchmarkPushPullPaused10(b *testing.B) { - r := New(1024 * 8) + r, _ := New(1024 * 8) defer r.Close() data := make([]byte, 1024) diff --git a/server.go b/server.go index 34c000b6..dd5448fb 100644 --- a/server.go +++ b/server.go @@ -169,6 +169,9 @@ func (s *Server) Start() error { if s.WriteBufferCount == 0 { s.WriteBufferCount = 256 } + if (s.WriteBufferCount & (s.WriteBufferCount - 1)) != 0 { + return fmt.Errorf("WriteBufferCount must be a power of two") + } // system functions if s.Listen == nil { diff --git a/servermulticasthandler.go b/servermulticasthandler.go index 2b44438b..7e010c15 100644 --- a/servermulticasthandler.go +++ b/servermulticasthandler.go @@ -26,10 +26,12 @@ func newServerMulticastHandler(s *Server) (*serverMulticastHandler, error) { return nil, err } + wb, _ := ringbuffer.New(uint64(s.WriteBufferCount)) + h := &serverMulticastHandler{ rtpl: rtpl, rtcpl: rtcpl, - writeBuffer: ringbuffer.New(uint64(s.WriteBufferCount)), + writeBuffer: wb, writerDone: make(chan struct{}), } diff --git a/serversession.go b/serversession.go index 644ed5c5..074e84bf 100644 --- a/serversession.go +++ b/serversession.go @@ -824,7 +824,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base // inside the callback. if ss.state != ServerSessionStatePlay && *ss.setuppedTransport != TransportUDPMulticast { - ss.writeBuffer = ringbuffer.New(uint64(ss.s.WriteBufferCount)) + ss.writeBuffer, _ = ringbuffer.New(uint64(ss.s.WriteBufferCount)) } res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ @@ -956,7 +956,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base // when recording, writeBuffer is only used to send RTCP receiver reports, // that are much smaller than RTP packets and are sent at a fixed interval. // decrease RAM consumption by allocating less buffers. - ss.writeBuffer = ringbuffer.New(uint64(8)) + ss.writeBuffer, _ = ringbuffer.New(uint64(8)) res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ Session: ss,