[ADDED] Context and timeout options to Messages.Next() plus Fetch context support (#1938)

* Add timeout support to MessagesContext.Next() with mutual exclusion
* [ADDED] FetchContext support for pull consumer operations
---------

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
Piotr Piotrowski
2025-09-18 13:17:21 +02:00
committed by GitHub
parent 3d0a13a355
commit 98a4735206
6 changed files with 498 additions and 18 deletions

View File

@@ -41,7 +41,7 @@ jobs:
strategy:
matrix:
go: [ "1.23", "1.24" ]
go: [ "1.24", "1.25" ]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -61,14 +61,14 @@ jobs:
shell: bash --noprofile --norc -x -eo pipefail {0}
run: |
go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast -vet=off
if [ "${{ matrix.go }}" = "1.24" ]; then
if [ "${{ matrix.go }}" = "1.25" ]; then
./scripts/cov.sh CI
else
go test -modfile=go_test.mod -race -v -p=1 ./... --failfast -vet=off -tags=internal_testing
fi
- name: Coveralls
if: matrix.go == '1.24'
if: matrix.go == '1.25'
uses: coverallsapp/github-action@v2
with:
file: acc.out

View File

@@ -14,6 +14,7 @@
package jetstream
import (
"context"
"fmt"
"time"
)
@@ -486,6 +487,7 @@ func FetchMaxWait(timeout time.Duration) FetchOpt {
return fmt.Errorf("%w: timeout value must be greater than 0", ErrInvalidOption)
}
req.Expires = timeout
req.maxWaitSet = true
return nil
}
}
@@ -508,6 +510,31 @@ func FetchHeartbeat(hb time.Duration) FetchOpt {
}
}
// FetchContext sets a context for the Fetch operation.
// The Fetch operation will be canceled if the context is canceled.
// If the context has a deadline, it will be used to set expiry on pull request.
func FetchContext(ctx context.Context) FetchOpt {
return func(req *pullRequest) error {
req.ctx = ctx
// If context has a deadline, use it to set expiry
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
if remaining <= 0 {
return fmt.Errorf("%w: context deadline already exceeded", ErrInvalidOption)
}
// Use 90% of remaining time for server (capped at 1s)
buffer := time.Duration(float64(remaining) * 0.1)
if buffer > time.Second {
buffer = time.Second
}
req.Expires = remaining - buffer
}
return nil
}
}
// WithDeletedDetails can be used to display the information about messages
// deleted from a stream on a stream info request
func WithDeletedDetails(deletedDetails bool) StreamInfoOpt {
@@ -648,3 +675,25 @@ func WithStallWait(ttl time.Duration) PublishOpt {
return nil
}
}
type nextOptFunc func(*nextOpts)
func (fn nextOptFunc) configureNext(opts *nextOpts) {
fn(opts)
}
// NextMaxWait sets a timeout for the Next operation.
// If the timeout is reached before a message is available, a timeout error is returned.
func NextMaxWait(timeout time.Duration) NextOpt {
return nextOptFunc(func(opts *nextOpts) {
opts.timeout = timeout
})
}
// NextContext sets a context for the Next operation.
// The Next operation will be canceled if the context is canceled.
func NextContext(ctx context.Context) NextOpt {
return nextOptFunc(func(opts *nextOpts) {
opts.ctx = ctx
})
}

View File

@@ -282,10 +282,21 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er
return sub, nil
}
func (s *orderedSubscription) Next() (Msg, error) {
func (s *orderedSubscription) Next(opts ...NextOpt) (Msg, error) {
for {
msg, err := s.consumer.currentSub.Next()
msg, err := s.consumer.currentSub.Next(opts...)
if err != nil {
// Check for errors which should be returned directly
// without resetting the consumer
if errors.Is(err, ErrInvalidOption) {
return nil, err
}
if errors.Is(err, nats.ErrTimeout) {
return nil, err
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
if errors.Is(err, ErrMsgIteratorClosed) {
s.Stop()
return nil, err

View File

@@ -14,6 +14,7 @@
package jetstream
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -34,8 +35,11 @@ type (
MessagesContext interface {
// Next retrieves next message on a stream. It will block until the next
// message is available. If the context is canceled, Next will return
// ErrMsgIteratorClosed error.
Next() (Msg, error)
// ErrMsgIteratorClosed error. An optional timeout or context can be
// provided using NextOpt options. If none are provided, Next will block
// indefinitely until a message is available, iterator is closed or a
// heartbeat error occurs.
Next(opts ...NextOpt) (Msg, error)
// Stop unsubscribes from the stream and cancels subscription. Calling
// Next after calling Stop will return ErrMsgIteratorClosed error.
@@ -92,15 +96,17 @@ type (
}
pullRequest struct {
Expires time.Duration `json:"expires,omitempty"`
Batch int `json:"batch,omitempty"`
MaxBytes int `json:"max_bytes,omitempty"`
NoWait bool `json:"no_wait,omitempty"`
Heartbeat time.Duration `json:"idle_heartbeat,omitempty"`
MinPending int64 `json:"min_pending,omitempty"`
MinAckPending int64 `json:"min_ack_pending,omitempty"`
PinID string `json:"id,omitempty"`
Group string `json:"group,omitempty"`
Expires time.Duration `json:"expires,omitempty"`
Batch int `json:"batch,omitempty"`
MaxBytes int `json:"max_bytes,omitempty"`
NoWait bool `json:"no_wait,omitempty"`
Heartbeat time.Duration `json:"idle_heartbeat,omitempty"`
MinPending int64 `json:"min_pending,omitempty"`
MinAckPending int64 `json:"min_ack_pending,omitempty"`
PinID string `json:"id,omitempty"`
Group string `json:"group,omitempty"`
ctx context.Context `json:"-"`
maxWaitSet bool `json:"-"`
}
consumeOpts struct {
@@ -167,6 +173,16 @@ type (
timer *time.Timer
sync.Mutex
}
// NextOpt is an option for configuring the behavior of MessagesContext.Next.
NextOpt interface {
configureNext(*nextOpts)
}
nextOpts struct {
timeout time.Duration
ctx context.Context
}
)
const (
@@ -569,7 +585,30 @@ var (
// Next retrieves next message on a stream. It will block until the next
// message is available. If the context is canceled, Next will return
// ErrMsgIteratorClosed error.
func (s *pullSubscription) Next() (Msg, error) {
func (s *pullSubscription) Next(opts ...NextOpt) (Msg, error) {
var nextOpts nextOpts
for _, opt := range opts {
opt.configureNext(&nextOpts)
}
if nextOpts.timeout > 0 && nextOpts.ctx != nil {
return nil, fmt.Errorf("%w: cannot specify both NextMaxWait and NextContext", ErrInvalidOption)
}
// Create timeout channel if needed
var timeoutCh <-chan time.Time
if nextOpts.timeout > 0 {
timer := time.NewTimer(nextOpts.timeout)
defer timer.Stop()
timeoutCh = timer.C
}
// Use context if provided
var ctxDone <-chan struct{}
if nextOpts.ctx != nil {
ctxDone = nextOpts.ctx.Done()
}
s.Lock()
defer s.Unlock()
drainMode := s.draining.Load() == 1
@@ -660,6 +699,10 @@ func (s *pullSubscription) Next() (Msg, error) {
}
isConnected = false
}
case <-timeoutCh:
return nil, nats.ErrTimeout
case <-ctxDone:
return nil, nextOpts.ctx.Err()
}
}
}
@@ -779,6 +822,11 @@ func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error)
return nil, err
}
}
if req.ctx != nil && req.maxWaitSet {
return nil, fmt.Errorf("%w: cannot specify both FetchContext and FetchMaxWait", ErrInvalidOption)
}
// if heartbeat was not explicitly set, set it to 5 seconds for longer pulls
// and disable it for shorter pulls
if req.Heartbeat == unset {
@@ -808,6 +856,11 @@ func (p *pullConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch,
return nil, err
}
}
if req.ctx != nil && req.maxWaitSet {
return nil, fmt.Errorf("%w: cannot specify both FetchContext and FetchMaxWait", ErrInvalidOption)
}
// if heartbeat was not explicitly set, set it to 5 seconds for longer pulls
// and disable it for shorter pulls
if req.Heartbeat == unset {
@@ -862,6 +915,13 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
var receivedMsgs, receivedBytes int
hbTimer := sub.scheduleHeartbeatCheck(req.Heartbeat)
// Use context if provided
var ctxDone <-chan struct{}
if req.ctx != nil {
ctxDone = req.ctx.Done()
}
go func(res *fetchResult) {
defer sub.subscription.Unsubscribe()
defer close(res.msgs)
@@ -922,6 +982,12 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
res.done = true
res.Unlock()
return
case <-ctxDone:
res.Lock()
res.err = req.ctx.Err()
res.done = true
res.Unlock()
return
}
}
}(res)

View File

@@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"
"sync"
"testing"
"time"
@@ -1835,7 +1836,7 @@ func TestOrderedConsumerInfo(t *testing.T) {
}
}
func TestOrderedConsumerNextTimeout(t *testing.T) {
func TestOrderedConsumerNextMaxWait(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
@@ -2122,6 +2123,111 @@ func TestOrderedConsumerConfig(t *testing.T) {
}
}
func TestOrderedConsumerMessagesNextWithTimeout(t *testing.T) {
t.Run("with timeout option", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx := context.Background()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer msgs.Stop()
// timeout when no messages are available
start := time.Now()
_, err = msgs.Next(jetstream.NextMaxWait(100 * time.Millisecond))
elapsed := time.Since(start)
if !errors.Is(err, nats.ErrTimeout) {
t.Fatalf("Expected timeout error; got: %v", err)
}
if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond {
t.Fatalf("Timeout not respected; elapsed: %v", elapsed)
}
// Publish a message and verify it can be fetched
if _, err := js.Publish(ctx, "FOO.A", []byte("msg1")); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
msg, err := msgs.Next(jetstream.NextMaxWait(1 * time.Second))
if err != nil {
t.Fatalf("Expected to receive message, got error: %v", err)
}
if string(msg.Data()) != "msg1" {
t.Fatalf("Unexpected message data; got: %s", msg.Data())
}
})
t.Run("context and timeout provided", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx := context.Background()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer msgs.Stop()
// ctx and timeout cannot be used together
testCtx, testCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer testCancel()
_, err = msgs.Next(jetstream.NextMaxWait(500*time.Millisecond), jetstream.NextContext(testCtx))
if err == nil {
t.Fatal("Expected error when providing both NextMaxWait and NextContext")
}
if !errors.Is(err, jetstream.ErrInvalidOption) {
t.Fatalf("Expected ErrInvalidOption, got: %v", err)
}
if !strings.Contains(err.Error(), "cannot specify both NextMaxWait and NextContext") {
t.Fatalf("Expected specific error message, got: %v", err)
}
})
}
func TestOrderedConsumerCloseConn(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

View File

@@ -476,6 +476,82 @@ func TestPullConsumerFetch(t *testing.T) {
t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err)
}
})
t.Run("with context", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// pull request should expire before client timeout
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
defer cancel()
result, err := c.Fetch(1, jetstream.FetchContext(ctx))
if err != nil {
t.Fatalf("Unexpected error from Fetch: %v", err)
}
msg, ok := <-result.Messages()
if ok {
t.Fatalf("Expected no message, got: %v", msg)
}
if result.Error() != nil {
t.Fatalf("Unexpected error during fetch: %v", result.Error())
}
// Test context cancellation
ctx, cancel = context.WithCancel(context.Background())
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
result, err = c.Fetch(1, jetstream.FetchContext(ctx))
if err != nil {
t.Fatalf("Unexpected error from Fetch: %v", err)
}
msg = <-result.Messages()
if msg != nil {
t.Fatalf("Expected no message, got: %v", msg)
}
err = result.Error()
if !errors.Is(err, context.Canceled) {
t.Fatalf("Expected context canceled error, got: %v", err)
}
// Test mutual exclusion with FetchMaxWait
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = c.Fetch(1, jetstream.FetchContext(ctx), jetstream.FetchMaxWait(time.Second))
if !errors.Is(err, jetstream.ErrInvalidOption) {
t.Fatalf("Expected mutual exclusion error, got: %v", err)
}
// Test already expired context
expiredCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
defer cancel()
_, err = c.Fetch(1, jetstream.FetchContext(expiredCtx))
if !errors.Is(err, jetstream.ErrInvalidOption) {
t.Fatalf("Expected invalid option error, got: %v", err)
}
})
}
func TestPullConsumerFetchRace(t *testing.T) {
@@ -3515,6 +3591,178 @@ func TestPullConsumerNext(t *testing.T) {
})
}
func TestPullConsumerMessagesNextWithTimeout(t *testing.T) {
t.Run("with timeout option", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer msgs.Stop()
// no msgs yet, should timeout
start := time.Now()
_, err = msgs.Next(jetstream.NextMaxWait(100 * time.Millisecond))
elapsed := time.Since(start)
if !errors.Is(err, nats.ErrTimeout) {
t.Fatalf("Expected timeout error; got: %v", err)
}
if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond {
t.Fatalf("Timeout not respected; elapsed: %v", elapsed)
}
// Publish a message and verify it can be fetched
if _, err := js.Publish(ctx, "FOO.A", []byte("msg1")); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
msg, err := msgs.Next(jetstream.NextMaxWait(1 * time.Second))
if err != nil {
t.Fatalf("Expected to receive message, got error: %v", err)
}
if string(msg.Data()) != "msg1" {
t.Fatalf("Unexpected message data; got: %s", msg.Data())
}
})
t.Run("with context option", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
s, err := js.CreateStream(context.Background(), jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(context.Background(), jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer msgs.Stop()
// context timeout
ctx1, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel2()
start := time.Now()
_, err = msgs.Next(jetstream.NextContext(ctx1))
elapsed := time.Since(start)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("Expected context deadline exceeded error; got: %v", err)
}
if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond {
t.Fatalf("Context timeout not respected; elapsed: %v", elapsed)
}
// cancel context before calling Next
ctx2, cancel3 := context.WithCancel(context.Background())
cancel3()
_, err = msgs.Next(jetstream.NextContext(ctx2))
if !errors.Is(err, context.Canceled) {
t.Fatalf("Expected context canceled error; got: %v", err)
}
// Publish a message and verify it can be fetched
if _, err := js.Publish(context.Background(), "FOO.A", []byte("msg1")); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
ctx3, cancel4 := context.WithTimeout(context.Background(), time.Second)
defer cancel4()
msg, err := msgs.Next(jetstream.NextContext(ctx3))
if err != nil {
t.Fatalf("Expected to receive message, got error: %v", err)
}
if string(msg.Data()) != "msg1" {
t.Fatalf("Unexpected message data; got: %s", msg.Data())
}
})
t.Run("context and timeout provided", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer msgs.Stop()
// Test that providing both NextMaxWait and NextContext returns an error
testCtx, testCancel := context.WithTimeout(context.Background(), time.Second)
defer testCancel()
_, err = msgs.Next(jetstream.NextMaxWait(500*time.Millisecond), jetstream.NextContext(testCtx))
if err == nil {
t.Fatal("Expected error when providing both NextMaxWait and NextContext")
}
if !errors.Is(err, jetstream.ErrInvalidOption) {
t.Fatalf("Expected ErrInvalidOption, got: %v", err)
}
if !errors.Is(err, jetstream.ErrInvalidOption) {
t.Fatalf("Expected specific error message, got: %v", err)
}
})
}
func TestPullConsumerConnectionClosed(t *testing.T) {
t.Run("messages", func(t *testing.T) {
srv := RunBasicJetStreamServer()