From 04a3c45f602fb391d650940f5bba18f1e3788366 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Sat, 26 Aug 2023 17:23:54 +0200 Subject: [PATCH] ringbuffer: discard pending data when buffer is closed (#387) --- pkg/ringbuffer/ringbuffer.go | 9 ++++++++ pkg/ringbuffer/ringbuffer_test.go | 36 +++++++++++++++---------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pkg/ringbuffer/ringbuffer.go b/pkg/ringbuffer/ringbuffer.go index 31f8c90e..364aea31 100644 --- a/pkg/ringbuffer/ringbuffer.go +++ b/pkg/ringbuffer/ringbuffer.go @@ -38,7 +38,14 @@ func New(size uint64) (*RingBuffer, error) { // Close makes Pull() return false. func (r *RingBuffer) Close() { r.mutex.Lock() + r.closed = true + + // discard pending data to make Pull() exit immediately + for i := uint64(0); i < r.size; i++ { + r.buffer[i] = nil + } + r.mutex.Unlock() r.cond.Broadcast() } @@ -48,6 +55,7 @@ func (r *RingBuffer) Reset() { for i := uint64(0); i < r.size; i++ { r.buffer[i] = nil } + r.writeIndex = 0 r.readIndex = 0 r.closed = false @@ -64,6 +72,7 @@ func (r *RingBuffer) Push(data interface{}) bool { r.buffer[r.writeIndex] = data r.writeIndex = (r.writeIndex + 1) % r.size + r.mutex.Unlock() r.cond.Broadcast() diff --git a/pkg/ringbuffer/ringbuffer_test.go b/pkg/ringbuffer/ringbuffer_test.go index 1c20a5f7..ff553096 100644 --- a/pkg/ringbuffer/ringbuffer_test.go +++ b/pkg/ringbuffer/ringbuffer_test.go @@ -51,30 +51,28 @@ func TestClose(t *testing.T) { r, err := New(1024) require.NoError(t, err) - done := make(chan struct{}) - go func() { - defer close(done) - - _, ok := r.Pull() - require.Equal(t, true, ok) - - _, ok = r.Pull() - require.Equal(t, false, ok) - }() - ok := r.Push([]byte{1, 2, 3, 4}) require.Equal(t, true, ok) - r.Close() - <-done - - r.Reset() - - ok = r.Push([]byte{5, 6, 7, 8}) - require.Equal(t, true, ok) - _, ok = r.Pull() require.Equal(t, true, ok) + + ok = r.Push([]byte{5, 6, 7, 8}) + require.Equal(t, true, ok) + + r.Close() + + _, ok = r.Pull() + require.Equal(t, false, ok) + + r.Reset() + + ok = r.Push([]byte{9, 10, 11, 12}) + require.Equal(t, true, ok) + + data, ok := r.Pull() + require.Equal(t, true, ok) + require.Equal(t, []byte{9, 10, 11, 12}, data) } func TestOverflow(t *testing.T) {