feat: add WaitForWithContext (#480)

* chore: fix test timeout helper

using os.Exit(1) kills everything, tests statuses are not always displayed

* chore: refactor WaitFor unit tests

zero-code changes

* fix: WaitFor on first condition

duration must be non-zero if first conditions is true

* feat: add WaitForWithContext

* chore: provide meaningful returned values for WaitFor and WaitForWithContext
This commit is contained in:
ccoVeille
2024-07-15 19:25:42 +02:00
committed by GitHub
parent 9e343973a4
commit 0f4679bf52
4 changed files with 255 additions and 51 deletions

View File

@@ -278,6 +278,7 @@ Concurrency helpers:
- [Async](#async)
- [Transaction](#transaction)
- [WaitFor](#waitfor)
- [WaitForWithContext](#waitforwithcontext)
Error handling:
@@ -3104,9 +3105,9 @@ laterTrue := func(i int) bool {
return i > 5
}
iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 0ms
// 1ms
// true
iterations, duration, ok := lo.WaitFor(alwaysFalse, 10*time.Millisecond, time.Millisecond)
@@ -3125,6 +3126,49 @@ iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Mi
// false
```
### WaitForWithContext
Runs periodically until a condition is validated or context is invalid.
The condition receives also the context, so it can invalidate the process in the condition checker
```go
ctx := context.Background()
alwaysTrue := func(_ context.Context, i int) bool { return true }
alwaysFalse := func(_ context.Context, i int) bool { return false }
laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}
iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 1ms
// true
iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond)
// 10
// 10ms
// false
iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond)
// 5
// 5ms
// true
iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond)
// 2
// 10ms
// false
expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond)
// 5
// 5.1ms
// false
```
### Validate
Helper function that creates an error when a condition is not met.

View File

@@ -1,6 +1,7 @@
package lo
import (
"context"
"sync"
"time"
)
@@ -98,33 +99,38 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,
}
// WaitFor runs periodically until a condition is validated.
func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
if condition(0) {
return 1, 0, true
func WaitFor(condition func(i int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
conditionWithContext := func(_ context.Context, currentIteration int) bool {
return condition(currentIteration)
}
return WaitForWithContext(context.Background(), conditionWithContext, timeout, heartbeatDelay)
}
// WaitForWithContext runs periodically until a condition is validated or context is canceled.
func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, currentIteration int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
start := time.Now()
timer := time.NewTimer(maxDuration)
ticker := time.NewTicker(tick)
if ctx.Err() != nil {
return totalIterations, time.Since(start), false
}
ctx, cleanCtx := context.WithTimeout(ctx, timeout)
ticker := time.NewTicker(heartbeatDelay)
defer func() {
timer.Stop()
cleanCtx()
ticker.Stop()
}()
i := 1
for {
select {
case <-timer.C:
return i, time.Since(start), false
case <-ctx.Done():
return totalIterations, time.Since(start), false
case <-ticker.C:
if condition(i) {
return i + 1, time.Since(start), true
totalIterations++
if condition(ctx, totalIterations-1) {
return totalIterations, time.Since(start), true
}
i++
}
}
}

View File

@@ -1,6 +1,7 @@
package lo
import (
"context"
"sync"
"testing"
"time"
@@ -215,44 +216,198 @@ func TestAsyncX(t *testing.T) {
func TestWaitFor(t *testing.T) {
t.Parallel()
testWithTimeout(t, 100*time.Millisecond)
is := assert.New(t)
alwaysTrue := func(i int) bool { return true }
alwaysFalse := func(i int) bool { return false }
testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter)
is.Equal(time.Duration(0), duration)
is.True(ok)
iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 4*time.Millisecond)
is.Equal(3, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
t.Run("exist condition works", func(t *testing.T) {
t.Parallel()
laterTrue := func(i int) bool {
return i >= 5
}
testWithTimeout(t, testTimeout)
is := assert.New(t)
iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(6, iter)
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, 5*time.Millisecond)
is.Equal(2, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
laterTrue := func(i int) bool {
return i >= 5
}
counter := 0
iter, duration, ok := WaitFor(laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})
alwaysFalse = func(i int) bool {
is.Equal(counter, i)
counter++
return false
}
t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()
iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 1050*time.Microsecond)
is.Equal(10, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
testWithTimeout(t, testTimeout)
is := assert.New(t)
counter := 0
alwaysFalse := func(i int) bool {
is.Equal(counter, i)
counter++
return false
}
iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
alwaysTrue := func(_ int) bool { return true }
alwaysFalse := func(_ int) bool { return false }
t.Run("short timeout works", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
t.Run("timeout works", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})
}
func TestWaitForWithContext(t *testing.T) {
t.Parallel()
testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond
t.Run("exist condition works", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}
iter, duration, ok := WaitForWithContext(context.Background(), laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})
t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
counter := 0
alwaysFalse := func(_ context.Context, i int) bool {
is.Equal(counter, i)
counter++
return false
}
iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
alwaysTrue := func(_ context.Context, _ int) bool { return true }
alwaysFalse := func(_ context.Context, _ int) bool { return false }
t.Run("short timeout works", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
t.Run("timeout works", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
iter, duration, ok := WaitForWithContext(context.Background(), alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})
t.Run("context cancellation stops everything", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
expiringCtx, clean := context.WithTimeout(context.Background(), 8*time.Millisecond)
t.Cleanup(func() {
clean()
})
iter, duration, ok := WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, 3*time.Millisecond)
is.Equal(2, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})
t.Run("canceled context stops everything", func(t *testing.T) {
t.Parallel()
testWithTimeout(t, testTimeout)
is := assert.New(t)
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
iter, duration, ok := WaitForWithContext(canceledCtx, alwaysFalse, 100*time.Millisecond, 1050*time.Microsecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(1*time.Millisecond, duration, float64(5*time.Microsecond))
is.False(ok)
})
}

View File

@@ -1,7 +1,6 @@
package lo
import (
"os"
"testing"
"time"
)
@@ -18,7 +17,7 @@ func testWithTimeout(t *testing.T, timeout time.Duration) {
case <-testFinished:
case <-time.After(timeout):
t.Errorf("test timed out after %s", timeout)
os.Exit(1)
t.FailNow()
}
}()
}