ringbuffer: force size to be a power of 2

otherwise buffer is used partially when writeIndex overflows.
This commit is contained in:
aler9
2022-07-05 19:25:53 +02:00
parent ef900359ce
commit 18005a9cde
6 changed files with 41 additions and 18 deletions

View File

@@ -318,6 +318,9 @@ func (c *Client) Start(scheme string, host string) error {
if c.WriteBufferCount == 0 { if c.WriteBufferCount == 0 {
c.WriteBufferCount = 256 c.WriteBufferCount = 256
} }
if (c.WriteBufferCount & (c.WriteBufferCount - 1)) != 0 {
return fmt.Errorf("WriteBufferCount must be a power of two")
}
if c.UserAgent == "" { if c.UserAgent == "" {
c.UserAgent = "gortsplib" c.UserAgent = "gortsplib"
} }
@@ -699,9 +702,9 @@ func (c *Client) playRecordStart() {
// when reading, writeBuffer is only used to send RTCP receiver reports, // when reading, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers. // decrease RAM consumption by allocating less buffers.
c.writeBuffer = ringbuffer.New(8) c.writeBuffer, _ = ringbuffer.New(8)
} else { } else {
c.writeBuffer = ringbuffer.New(uint64(c.WriteBufferCount)) c.writeBuffer, _ = ringbuffer.New(uint64(c.WriteBufferCount))
} }
c.writerRunning = true c.writerRunning = true
c.writerDone = make(chan struct{}) c.writerDone = make(chan struct{})

View File

@@ -2,13 +2,14 @@
package ringbuffer package ringbuffer
import ( import (
"fmt"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
) )
// RingBuffer is a ring buffer. // RingBuffer is a ring buffer.
type RingBuffer struct { type RingBuffer struct {
bufferSize uint64 size uint64
readIndex uint64 readIndex uint64
writeIndex uint64 writeIndex uint64
closed int64 closed int64
@@ -17,14 +18,20 @@ type RingBuffer struct {
} }
// New allocates a RingBuffer. // New allocates a RingBuffer.
func New(size uint64) *RingBuffer { func New(size uint64) (*RingBuffer, error) {
// when writeIndex overflows, if size is not a power of
// two, only a portion of the buffer is used.
if (size & (size - 1)) != 0 {
return nil, fmt.Errorf("size must be a power of two")
}
return &RingBuffer{ return &RingBuffer{
bufferSize: size, size: size,
readIndex: 1, readIndex: 1,
writeIndex: 0, writeIndex: 0,
buffer: make([]unsafe.Pointer, size), buffer: make([]unsafe.Pointer, size),
event: newEvent(), event: newEvent(),
} }, nil
} }
// Close makes Pull() return false. // Close makes Pull() return false.
@@ -35,7 +42,7 @@ func (r *RingBuffer) Close() {
// Reset restores Pull() after a Close(). // Reset restores Pull() after a Close().
func (r *RingBuffer) Reset() { func (r *RingBuffer) Reset() {
for i := uint64(0); i < r.bufferSize; i++ { for i := uint64(0); i < r.size; i++ {
atomic.SwapPointer(&r.buffer[i], nil) atomic.SwapPointer(&r.buffer[i], nil)
} }
atomic.SwapUint64(&r.writeIndex, 0) atomic.SwapUint64(&r.writeIndex, 0)
@@ -46,7 +53,7 @@ func (r *RingBuffer) Reset() {
// Push pushes some data at the end of the buffer. // Push pushes some data at the end of the buffer.
func (r *RingBuffer) Push(data interface{}) { func (r *RingBuffer) Push(data interface{}) {
writeIndex := atomic.AddUint64(&r.writeIndex, 1) writeIndex := atomic.AddUint64(&r.writeIndex, 1)
i := writeIndex % r.bufferSize i := writeIndex % r.size
atomic.SwapPointer(&r.buffer[i], unsafe.Pointer(&data)) atomic.SwapPointer(&r.buffer[i], unsafe.Pointer(&data))
r.event.signal() r.event.signal()
} }
@@ -54,7 +61,7 @@ func (r *RingBuffer) Push(data interface{}) {
// Pull pulls some data from the beginning of the buffer. // Pull pulls some data from the beginning of the buffer.
func (r *RingBuffer) Pull() (interface{}, bool) { func (r *RingBuffer) Pull() (interface{}, bool) {
for { for {
i := r.readIndex % r.bufferSize i := r.readIndex % r.size
res := (*interface{})(atomic.SwapPointer(&r.buffer[i], nil)) res := (*interface{})(atomic.SwapPointer(&r.buffer[i], nil))
if res == nil { if res == nil {
if atomic.SwapInt64(&r.closed, 0) == 1 { if atomic.SwapInt64(&r.closed, 0) == 1 {

View File

@@ -8,8 +8,14 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestCreateError(t *testing.T) {
_, err := New(1000)
require.EqualError(t, err, "size must be a power of two")
}
func TestPushBeforePull(t *testing.T) { func TestPushBeforePull(t *testing.T) {
r := New(1024) r, err := New(1024)
require.NoError(t, err)
defer r.Close() defer r.Close()
data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4) data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4)
@@ -21,7 +27,8 @@ func TestPushBeforePull(t *testing.T) {
} }
func TestPullBeforePush(t *testing.T) { func TestPullBeforePush(t *testing.T) {
r := New(1024) r, err := New(1024)
require.NoError(t, err)
defer r.Close() defer r.Close()
data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4) data := bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 1024/4)
@@ -41,7 +48,8 @@ func TestPullBeforePush(t *testing.T) {
} }
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
r := New(1024) r, err := New(1024)
require.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@@ -68,7 +76,7 @@ func TestClose(t *testing.T) {
} }
func BenchmarkPushPullContinuous(b *testing.B) { func BenchmarkPushPullContinuous(b *testing.B) {
r := New(1024 * 8) r, _ := New(1024 * 8)
defer r.Close() defer r.Close()
data := make([]byte, 1024) data := make([]byte, 1024)
@@ -91,7 +99,7 @@ func BenchmarkPushPullContinuous(b *testing.B) {
} }
func BenchmarkPushPullPaused5(b *testing.B) { func BenchmarkPushPullPaused5(b *testing.B) {
r := New(128) r, _ := New(128)
defer r.Close() defer r.Close()
data := make([]byte, 1024) data := make([]byte, 1024)
@@ -115,7 +123,7 @@ func BenchmarkPushPullPaused5(b *testing.B) {
} }
func BenchmarkPushPullPaused10(b *testing.B) { func BenchmarkPushPullPaused10(b *testing.B) {
r := New(1024 * 8) r, _ := New(1024 * 8)
defer r.Close() defer r.Close()
data := make([]byte, 1024) data := make([]byte, 1024)

View File

@@ -169,6 +169,9 @@ func (s *Server) Start() error {
if s.WriteBufferCount == 0 { if s.WriteBufferCount == 0 {
s.WriteBufferCount = 256 s.WriteBufferCount = 256
} }
if (s.WriteBufferCount & (s.WriteBufferCount - 1)) != 0 {
return fmt.Errorf("WriteBufferCount must be a power of two")
}
// system functions // system functions
if s.Listen == nil { if s.Listen == nil {

View File

@@ -26,10 +26,12 @@ func newServerMulticastHandler(s *Server) (*serverMulticastHandler, error) {
return nil, err return nil, err
} }
wb, _ := ringbuffer.New(uint64(s.WriteBufferCount))
h := &serverMulticastHandler{ h := &serverMulticastHandler{
rtpl: rtpl, rtpl: rtpl,
rtcpl: rtcpl, rtcpl: rtcpl,
writeBuffer: ringbuffer.New(uint64(s.WriteBufferCount)), writeBuffer: wb,
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
} }

View File

@@ -824,7 +824,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
// inside the callback. // inside the callback.
if ss.state != ServerSessionStatePlay && if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast { *ss.setuppedTransport != TransportUDPMulticast {
ss.writeBuffer = ringbuffer.New(uint64(ss.s.WriteBufferCount)) ss.writeBuffer, _ = ringbuffer.New(uint64(ss.s.WriteBufferCount))
} }
res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{
@@ -956,7 +956,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
// when recording, writeBuffer is only used to send RTCP receiver reports, // when recording, writeBuffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval. // that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers. // decrease RAM consumption by allocating less buffers.
ss.writeBuffer = ringbuffer.New(uint64(8)) ss.writeBuffer, _ = ringbuffer.New(uint64(8))
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
Session: ss, Session: ss,