diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e2ac511..79bab7f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -41,7 +41,7 @@ jobs: strategy: matrix: - go: [ "1.23", "1.24" ] + go: [ "1.24", "1.25" ] steps: - name: Checkout code uses: actions/checkout@v4 @@ -61,14 +61,14 @@ jobs: shell: bash --noprofile --norc -x -eo pipefail {0} run: | go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast -vet=off - if [ "${{ matrix.go }}" = "1.24" ]; then + if [ "${{ matrix.go }}" = "1.25" ]; then ./scripts/cov.sh CI else go test -modfile=go_test.mod -race -v -p=1 ./... --failfast -vet=off -tags=internal_testing fi - name: Coveralls - if: matrix.go == '1.24' + if: matrix.go == '1.25' uses: coverallsapp/github-action@v2 with: file: acc.out \ No newline at end of file diff --git a/jetstream/jetstream_options.go b/jetstream/jetstream_options.go index 0fffbc7..fa016ca 100644 --- a/jetstream/jetstream_options.go +++ b/jetstream/jetstream_options.go @@ -14,6 +14,7 @@ package jetstream import ( + "context" "fmt" "time" ) @@ -486,6 +487,7 @@ func FetchMaxWait(timeout time.Duration) FetchOpt { return fmt.Errorf("%w: timeout value must be greater than 0", ErrInvalidOption) } req.Expires = timeout + req.maxWaitSet = true return nil } } @@ -508,6 +510,31 @@ func FetchHeartbeat(hb time.Duration) FetchOpt { } } +// FetchContext sets a context for the Fetch operation. +// The Fetch operation will be canceled if the context is canceled. +// If the context has a deadline, it will be used to set expiry on pull request. +func FetchContext(ctx context.Context) FetchOpt { + return func(req *pullRequest) error { + req.ctx = ctx + + // If context has a deadline, use it to set expiry + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining <= 0 { + return fmt.Errorf("%w: context deadline already exceeded", ErrInvalidOption) + } + // Use 90% of remaining time for server (capped at 1s) + buffer := time.Duration(float64(remaining) * 0.1) + if buffer > time.Second { + buffer = time.Second + } + req.Expires = remaining - buffer + } + + return nil + } +} + // WithDeletedDetails can be used to display the information about messages // deleted from a stream on a stream info request func WithDeletedDetails(deletedDetails bool) StreamInfoOpt { @@ -648,3 +675,25 @@ func WithStallWait(ttl time.Duration) PublishOpt { return nil } } + +type nextOptFunc func(*nextOpts) + +func (fn nextOptFunc) configureNext(opts *nextOpts) { + fn(opts) +} + +// NextMaxWait sets a timeout for the Next operation. +// If the timeout is reached before a message is available, a timeout error is returned. +func NextMaxWait(timeout time.Duration) NextOpt { + return nextOptFunc(func(opts *nextOpts) { + opts.timeout = timeout + }) +} + +// NextContext sets a context for the Next operation. +// The Next operation will be canceled if the context is canceled. +func NextContext(ctx context.Context) NextOpt { + return nextOptFunc(func(opts *nextOpts) { + opts.ctx = ctx + }) +} diff --git a/jetstream/ordered.go b/jetstream/ordered.go index 35705a5..4d10c29 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -282,10 +282,21 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er return sub, nil } -func (s *orderedSubscription) Next() (Msg, error) { +func (s *orderedSubscription) Next(opts ...NextOpt) (Msg, error) { for { - msg, err := s.consumer.currentSub.Next() + msg, err := s.consumer.currentSub.Next(opts...) if err != nil { + // Check for errors which should be returned directly + // without resetting the consumer + if errors.Is(err, ErrInvalidOption) { + return nil, err + } + if errors.Is(err, nats.ErrTimeout) { + return nil, err + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } if errors.Is(err, ErrMsgIteratorClosed) { s.Stop() return nil, err diff --git a/jetstream/pull.go b/jetstream/pull.go index 3aacc0b..6446b95 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -14,6 +14,7 @@ package jetstream import ( + "context" "encoding/json" "errors" "fmt" @@ -34,8 +35,11 @@ type ( MessagesContext interface { // Next retrieves next message on a stream. It will block until the next // message is available. If the context is canceled, Next will return - // ErrMsgIteratorClosed error. - Next() (Msg, error) + // ErrMsgIteratorClosed error. An optional timeout or context can be + // provided using NextOpt options. If none are provided, Next will block + // indefinitely until a message is available, iterator is closed or a + // heartbeat error occurs. + Next(opts ...NextOpt) (Msg, error) // Stop unsubscribes from the stream and cancels subscription. Calling // Next after calling Stop will return ErrMsgIteratorClosed error. @@ -92,15 +96,17 @@ type ( } pullRequest struct { - Expires time.Duration `json:"expires,omitempty"` - Batch int `json:"batch,omitempty"` - MaxBytes int `json:"max_bytes,omitempty"` - NoWait bool `json:"no_wait,omitempty"` - Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` - MinPending int64 `json:"min_pending,omitempty"` - MinAckPending int64 `json:"min_ack_pending,omitempty"` - PinID string `json:"id,omitempty"` - Group string `json:"group,omitempty"` + Expires time.Duration `json:"expires,omitempty"` + Batch int `json:"batch,omitempty"` + MaxBytes int `json:"max_bytes,omitempty"` + NoWait bool `json:"no_wait,omitempty"` + Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` + MinPending int64 `json:"min_pending,omitempty"` + MinAckPending int64 `json:"min_ack_pending,omitempty"` + PinID string `json:"id,omitempty"` + Group string `json:"group,omitempty"` + ctx context.Context `json:"-"` + maxWaitSet bool `json:"-"` } consumeOpts struct { @@ -167,6 +173,16 @@ type ( timer *time.Timer sync.Mutex } + + // NextOpt is an option for configuring the behavior of MessagesContext.Next. + NextOpt interface { + configureNext(*nextOpts) + } + + nextOpts struct { + timeout time.Duration + ctx context.Context + } ) const ( @@ -569,7 +585,30 @@ var ( // Next retrieves next message on a stream. It will block until the next // message is available. If the context is canceled, Next will return // ErrMsgIteratorClosed error. -func (s *pullSubscription) Next() (Msg, error) { +func (s *pullSubscription) Next(opts ...NextOpt) (Msg, error) { + var nextOpts nextOpts + for _, opt := range opts { + opt.configureNext(&nextOpts) + } + + if nextOpts.timeout > 0 && nextOpts.ctx != nil { + return nil, fmt.Errorf("%w: cannot specify both NextMaxWait and NextContext", ErrInvalidOption) + } + + // Create timeout channel if needed + var timeoutCh <-chan time.Time + if nextOpts.timeout > 0 { + timer := time.NewTimer(nextOpts.timeout) + defer timer.Stop() + timeoutCh = timer.C + } + + // Use context if provided + var ctxDone <-chan struct{} + if nextOpts.ctx != nil { + ctxDone = nextOpts.ctx.Done() + } + s.Lock() defer s.Unlock() drainMode := s.draining.Load() == 1 @@ -660,6 +699,10 @@ func (s *pullSubscription) Next() (Msg, error) { } isConnected = false } + case <-timeoutCh: + return nil, nats.ErrTimeout + case <-ctxDone: + return nil, nextOpts.ctx.Err() } } } @@ -779,6 +822,11 @@ func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) return nil, err } } + + if req.ctx != nil && req.maxWaitSet { + return nil, fmt.Errorf("%w: cannot specify both FetchContext and FetchMaxWait", ErrInvalidOption) + } + // if heartbeat was not explicitly set, set it to 5 seconds for longer pulls // and disable it for shorter pulls if req.Heartbeat == unset { @@ -808,6 +856,11 @@ func (p *pullConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, return nil, err } } + + if req.ctx != nil && req.maxWaitSet { + return nil, fmt.Errorf("%w: cannot specify both FetchContext and FetchMaxWait", ErrInvalidOption) + } + // if heartbeat was not explicitly set, set it to 5 seconds for longer pulls // and disable it for shorter pulls if req.Heartbeat == unset { @@ -862,6 +915,13 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { var receivedMsgs, receivedBytes int hbTimer := sub.scheduleHeartbeatCheck(req.Heartbeat) + + // Use context if provided + var ctxDone <-chan struct{} + if req.ctx != nil { + ctxDone = req.ctx.Done() + } + go func(res *fetchResult) { defer sub.subscription.Unsubscribe() defer close(res.msgs) @@ -922,6 +982,12 @@ func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { res.done = true res.Unlock() return + case <-ctxDone: + res.Lock() + res.err = req.ctx.Err() + res.done = true + res.Unlock() + return } } }(res) diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index 5c05121..c9e6e5d 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "reflect" + "strings" "sync" "testing" "time" @@ -1835,7 +1836,7 @@ func TestOrderedConsumerInfo(t *testing.T) { } } -func TestOrderedConsumerNextTimeout(t *testing.T) { +func TestOrderedConsumerNextMaxWait(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) nc, err := nats.Connect(srv.ClientURL()) @@ -2122,6 +2123,111 @@ func TestOrderedConsumerConfig(t *testing.T) { } } +func TestOrderedConsumerMessagesNextWithTimeout(t *testing.T) { + t.Run("with timeout option", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx := context.Background() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer msgs.Stop() + + // timeout when no messages are available + start := time.Now() + _, err = msgs.Next(jetstream.NextMaxWait(100 * time.Millisecond)) + elapsed := time.Since(start) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) + } + if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond { + t.Fatalf("Timeout not respected; elapsed: %v", elapsed) + } + + // Publish a message and verify it can be fetched + if _, err := js.Publish(ctx, "FOO.A", []byte("msg1")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + + msg, err := msgs.Next(jetstream.NextMaxWait(1 * time.Second)) + if err != nil { + t.Fatalf("Expected to receive message, got error: %v", err) + } + if string(msg.Data()) != "msg1" { + t.Fatalf("Unexpected message data; got: %s", msg.Data()) + } + }) + + t.Run("context and timeout provided", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx := context.Background() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer msgs.Stop() + + // ctx and timeout cannot be used together + testCtx, testCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer testCancel() + + _, err = msgs.Next(jetstream.NextMaxWait(500*time.Millisecond), jetstream.NextContext(testCtx)) + if err == nil { + t.Fatal("Expected error when providing both NextMaxWait and NextContext") + } + if !errors.Is(err, jetstream.ErrInvalidOption) { + t.Fatalf("Expected ErrInvalidOption, got: %v", err) + } + if !strings.Contains(err.Error(), "cannot specify both NextMaxWait and NextContext") { + t.Fatalf("Expected specific error message, got: %v", err) + } + }) +} + func TestOrderedConsumerCloseConn(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) diff --git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index 1eb5c3c..d4fb8b9 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -476,6 +476,82 @@ func TestPullConsumerFetch(t *testing.T) { t.Fatalf("Expected error: %v; got: %v", jetstream.ErrInvalidOption, err) } }) + + t.Run("with context", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // pull request should expire before client timeout + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + result, err := c.Fetch(1, jetstream.FetchContext(ctx)) + if err != nil { + t.Fatalf("Unexpected error from Fetch: %v", err) + } + msg, ok := <-result.Messages() + if ok { + t.Fatalf("Expected no message, got: %v", msg) + } + if result.Error() != nil { + t.Fatalf("Unexpected error during fetch: %v", result.Error()) + } + + // Test context cancellation + ctx, cancel = context.WithCancel(context.Background()) + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + result, err = c.Fetch(1, jetstream.FetchContext(ctx)) + if err != nil { + t.Fatalf("Unexpected error from Fetch: %v", err) + } + msg = <-result.Messages() + if msg != nil { + t.Fatalf("Expected no message, got: %v", msg) + } + err = result.Error() + if !errors.Is(err, context.Canceled) { + t.Fatalf("Expected context canceled error, got: %v", err) + } + + // Test mutual exclusion with FetchMaxWait + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = c.Fetch(1, jetstream.FetchContext(ctx), jetstream.FetchMaxWait(time.Second)) + if !errors.Is(err, jetstream.ErrInvalidOption) { + t.Fatalf("Expected mutual exclusion error, got: %v", err) + } + + // Test already expired context + expiredCtx, cancel := context.WithTimeout(context.Background(), -time.Second) + defer cancel() + _, err = c.Fetch(1, jetstream.FetchContext(expiredCtx)) + if !errors.Is(err, jetstream.ErrInvalidOption) { + t.Fatalf("Expected invalid option error, got: %v", err) + } + }) } func TestPullConsumerFetchRace(t *testing.T) { @@ -3515,6 +3591,178 @@ func TestPullConsumerNext(t *testing.T) { }) } +func TestPullConsumerMessagesNextWithTimeout(t *testing.T) { + t.Run("with timeout option", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer msgs.Stop() + + // no msgs yet, should timeout + start := time.Now() + _, err = msgs.Next(jetstream.NextMaxWait(100 * time.Millisecond)) + elapsed := time.Since(start) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected timeout error; got: %v", err) + } + if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond { + t.Fatalf("Timeout not respected; elapsed: %v", elapsed) + } + + // Publish a message and verify it can be fetched + if _, err := js.Publish(ctx, "FOO.A", []byte("msg1")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + + msg, err := msgs.Next(jetstream.NextMaxWait(1 * time.Second)) + if err != nil { + t.Fatalf("Expected to receive message, got error: %v", err) + } + if string(msg.Data()) != "msg1" { + t.Fatalf("Unexpected message data; got: %s", msg.Data()) + } + }) + + t.Run("with context option", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + s, err := js.CreateStream(context.Background(), jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(context.Background(), jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer msgs.Stop() + + // context timeout + ctx1, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + + start := time.Now() + _, err = msgs.Next(jetstream.NextContext(ctx1)) + elapsed := time.Since(start) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Expected context deadline exceeded error; got: %v", err) + } + if elapsed < 100*time.Millisecond || elapsed > 200*time.Millisecond { + t.Fatalf("Context timeout not respected; elapsed: %v", elapsed) + } + + // cancel context before calling Next + ctx2, cancel3 := context.WithCancel(context.Background()) + cancel3() + _, err = msgs.Next(jetstream.NextContext(ctx2)) + if !errors.Is(err, context.Canceled) { + t.Fatalf("Expected context canceled error; got: %v", err) + } + + // Publish a message and verify it can be fetched + if _, err := js.Publish(context.Background(), "FOO.A", []byte("msg1")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + + ctx3, cancel4 := context.WithTimeout(context.Background(), time.Second) + defer cancel4() + msg, err := msgs.Next(jetstream.NextContext(ctx3)) + if err != nil { + t.Fatalf("Expected to receive message, got error: %v", err) + } + if string(msg.Data()) != "msg1" { + t.Fatalf("Unexpected message data; got: %s", msg.Data()) + } + }) + + t.Run("context and timeout provided", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer msgs.Stop() + + // Test that providing both NextMaxWait and NextContext returns an error + testCtx, testCancel := context.WithTimeout(context.Background(), time.Second) + defer testCancel() + + _, err = msgs.Next(jetstream.NextMaxWait(500*time.Millisecond), jetstream.NextContext(testCtx)) + if err == nil { + t.Fatal("Expected error when providing both NextMaxWait and NextContext") + } + if !errors.Is(err, jetstream.ErrInvalidOption) { + t.Fatalf("Expected ErrInvalidOption, got: %v", err) + } + if !errors.Is(err, jetstream.ErrInvalidOption) { + t.Fatalf("Expected specific error message, got: %v", err) + } + }) +} + func TestPullConsumerConnectionClosed(t *testing.T) { t.Run("messages", func(t *testing.T) { srv := RunBasicJetStreamServer()