diff --git a/sync.go b/sync.go index 9ed630a..59220e0 100644 --- a/sync.go +++ b/sync.go @@ -602,3 +602,120 @@ func (m *FIFOMutex) Unlock() { m.waiting[0].L.Unlock() m.waiting = m.waiting[1:] } + +// BufferedBatcher is a Chan-like object that: +// - processes all added items in the provided callback as a batch so that they're all processed together +// - doesn't block when adding an item while a batch is being processed but add it to the next batch +// - if an item is added several times to the same batch, it will be processed only once in the next batch +type BufferedBatcher struct { + batch map[any]bool // Locked by c's mutex + c *sync.Cond + cancel context.CancelFunc + ctx context.Context + mc sync.Mutex // Locks cancel and ctx + onBatch BufferedBatcherOnBatchFunc +} + +type BufferedBatcherOnBatchFunc func(ctx context.Context, batch []any) + +type BufferedBatcherOptions struct { + OnBatch BufferedBatcherOnBatchFunc +} + +func NewBufferedBatcher(o BufferedBatcherOptions) *BufferedBatcher { + return &BufferedBatcher{ + batch: make(map[any]bool), + c: sync.NewCond(&sync.Mutex{}), + onBatch: o.OnBatch, + } +} + +func (bb *BufferedBatcher) Start(ctx context.Context) { + // Already running + bb.mc.Lock() + if bb.ctx != nil && bb.ctx.Err() == nil { + bb.mc.Unlock() + return + } + + // Create context + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Store context + bb.ctx = ctx + bb.cancel = cancel + bb.mc.Unlock() + + // Handle context + go func() { + // Wait for context to be done + <-ctx.Done() + + // Signal + bb.c.L.Lock() + bb.c.Signal() + bb.c.L.Unlock() + }() + + // Loop + for { + // Context has been canceled + if ctx.Err() != nil { + return + } + + // Wait for batch + bb.c.L.Lock() + if len(bb.batch) == 0 { + bb.c.Wait() + bb.c.L.Unlock() + continue + } + + // Copy batch into a slice + var batch []any + for i := range bb.batch { + batch = append(batch, i) + } + + // Reset batch + bb.batch = map[any]bool{} + + // Unlock + bb.c.L.Unlock() + + // Callback + bb.onBatch(ctx, batch) + } +} + +func (bb *BufferedBatcher) Add(i any) { + // Lock + bb.c.L.Lock() + defer bb.c.L.Unlock() + + // Store + bb.batch[i] = true + + // Signal + bb.c.Signal() +} + +func (bb *BufferedBatcher) Stop() { + // Lock + bb.mc.Lock() + defer bb.mc.Unlock() + + // Not running + if bb.ctx == nil { + return + } + + // Cancel + bb.cancel() + + // Reset context + bb.ctx = nil + bb.cancel = nil +} diff --git a/sync_test.go b/sync_test.go index fa7ebb3..aa60d99 100644 --- a/sync_test.go +++ b/sync_test.go @@ -2,6 +2,7 @@ package astikit import ( "context" + "errors" "fmt" "reflect" "strings" @@ -185,10 +186,10 @@ func TestDebugMutex(t *testing.T) { if e, g := 1, len(ss); e != g { t.Fatalf("expected %d, got %d", e, g) } - if s, g := "sync_test.go:176", ss[0]; !strings.Contains(g, s) { + if s, g := "sync_test.go:177", ss[0]; !strings.Contains(g, s) { t.Fatalf("%s doesn't contain %s", g, s) } - if s, g := "sync_test.go:181", ss[0]; !strings.Contains(g, s) { + if s, g := "sync_test.go:182", ss[0]; !strings.Contains(g, s) { t.Fatalf("%s doesn't contain %s", g, s) } } @@ -218,3 +219,83 @@ func testFIFOMutex(i int, m *FIFOMutex, r *[]int, wg *sync.WaitGroup) { m.Unlock() }() } + +func TestBufferedBatcher(t *testing.T) { + var count int + var batches []map[any]int + var bb1 *BufferedBatcher + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + bb1 = NewBufferedBatcher(BufferedBatcherOptions{OnBatch: func(ctx context.Context, batch []any) { + count++ + if len(batch) > 0 { + m := make(map[any]int) + for _, i := range batch { + m[i]++ + } + batches = append(batches, m) + } + switch count { + case 1: + bb1.Add(1) + bb1.Add(1) + bb1.Add(2) + case 2: + bb1.Add(2) + bb1.Add(2) + bb1.Add(3) + case 3: + bb1.Add(1) + bb1.Add(1) + bb1.Add(2) + bb1.Add(2) + bb1.Add(3) + bb1.Add(3) + case 4: + go func() { + time.Sleep(100 * time.Millisecond) + bb1.Add(1) + }() + case 5: + cancel1() + } + }}) + bb1.Add(1) + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + go func() { + defer cancel2() + bb1.Start(ctx1) + }() + <-ctx2.Done() + if errors.Is(ctx2.Err(), context.DeadlineExceeded) { + t.Fatal("expected nothing, got timeout") + } + if e, g := []map[any]int{ + {1: 1}, + {1: 1, 2: 1}, + {2: 1, 3: 1}, + {1: 1, 2: 1, 3: 1}, + {1: 1}, + }, batches; !reflect.DeepEqual(e, g) { + t.Fatalf("expected %+v, got %+v", e, g) + } + + var bb2 *BufferedBatcher + bb2 = NewBufferedBatcher(BufferedBatcherOptions{OnBatch: func(ctx context.Context, batch []any) { + bb2.Start(context.Background()) + bb2.Stop() + bb2.Stop() + }}) + bb2.Add(1) + ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second) + defer cancel3() + go func() { + defer cancel3() + bb2.Start(context.Background()) + }() + <-ctx3.Done() + if errors.Is(ctx3.Err(), context.DeadlineExceeded) { + t.Fatal("expected nothing, got timeout") + } +}