mirror of
https://github.com/nats-io/nats.go.git
synced 2025-09-26 20:41:41 +08:00

This fixes an issue where a deadlock could occur when calling `Stop()` or `Drain()` on `ConsumeContext` or `MessagesContext` and then calling `Consume` or `Messages` immediately. Switched to using a type-safe implementation of `sync.Map` for the subscriptions map instead of locking the whole consumer state. Additionally, changed the type of atomic flags from `uint32` to `atomic.Uint32` to avoid accidental non-atomic reads/writes. Signed-off-by: Piotr Piotrowski <piotr@synadia.com> --------- Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
984 lines
25 KiB
Go
984 lines
25 KiB
Go
// Copyright 2022-2024 The NATS Authors
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package jetstream
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/nats-io/nats.go"
|
|
"github.com/nats-io/nats.go/internal/syncx"
|
|
"github.com/nats-io/nuid"
|
|
)
|
|
|
|
type (
|
|
// MessagesContext supports iterating over a messages on a stream.
|
|
// It is returned by [Consumer.Messages] method.
|
|
MessagesContext interface {
|
|
// Next retrieves next message on a stream. It will block until the next
|
|
// message is available. If the context is canceled, Next will return
|
|
// ErrMsgIteratorClosed error.
|
|
Next() (Msg, error)
|
|
|
|
// Stop unsubscribes from the stream and cancels subscription. Calling
|
|
// Next after calling Stop will return ErrMsgIteratorClosed error.
|
|
// All messages that are already in the buffer are discarded.
|
|
Stop()
|
|
|
|
// Drain unsubscribes from the stream and cancels subscription. All
|
|
// messages that are already in the buffer will be available on
|
|
// subsequent calls to Next. After the buffer is drained, Next will
|
|
// return ErrMsgIteratorClosed error.
|
|
Drain()
|
|
}
|
|
|
|
// ConsumeContext supports processing incoming messages from a stream.
|
|
// It is returned by [Consumer.Consume] method.
|
|
ConsumeContext interface {
|
|
// Stop unsubscribes from the stream and cancels subscription.
|
|
// No more messages will be received after calling this method.
|
|
// All messages that are already in the buffer are discarded.
|
|
Stop()
|
|
|
|
// Drain unsubscribes from the stream and cancels subscription.
|
|
// All messages that are already in the buffer will be processed in callback function.
|
|
Drain()
|
|
}
|
|
|
|
// MessageHandler is a handler function used as callback in [Consume].
|
|
MessageHandler func(msg Msg)
|
|
|
|
// PullConsumeOpt represent additional options used in [Consume] for pull consumers.
|
|
PullConsumeOpt interface {
|
|
configureConsume(*consumeOpts) error
|
|
}
|
|
|
|
// PullMessagesOpt represent additional options used in [Messages] for pull consumers.
|
|
PullMessagesOpt interface {
|
|
configureMessages(*consumeOpts) error
|
|
}
|
|
|
|
pullConsumer struct {
|
|
sync.Mutex
|
|
jetStream *jetStream
|
|
stream string
|
|
durable bool
|
|
name string
|
|
info *ConsumerInfo
|
|
subs syncx.Map[string, *pullSubscription]
|
|
}
|
|
|
|
pullRequest struct {
|
|
Expires time.Duration `json:"expires,omitempty"`
|
|
Batch int `json:"batch,omitempty"`
|
|
MaxBytes int `json:"max_bytes,omitempty"`
|
|
NoWait bool `json:"no_wait,omitempty"`
|
|
Heartbeat time.Duration `json:"idle_heartbeat,omitempty"`
|
|
}
|
|
|
|
consumeOpts struct {
|
|
Expires time.Duration
|
|
MaxMessages int
|
|
MaxBytes int
|
|
Heartbeat time.Duration
|
|
ErrHandler ConsumeErrHandlerFunc
|
|
ReportMissingHeartbeats bool
|
|
ThresholdMessages int
|
|
ThresholdBytes int
|
|
StopAfter int
|
|
stopAfterMsgsLeft chan int
|
|
notifyOnReconnect bool
|
|
}
|
|
|
|
ConsumeErrHandlerFunc func(consumeCtx ConsumeContext, err error)
|
|
|
|
pullSubscription struct {
|
|
sync.Mutex
|
|
id string
|
|
consumer *pullConsumer
|
|
subscription *nats.Subscription
|
|
msgs chan *nats.Msg
|
|
errs chan error
|
|
pending pendingMsgs
|
|
hbMonitor *hbMonitor
|
|
fetchInProgress atomic.Uint32
|
|
closed atomic.Uint32
|
|
draining atomic.Uint32
|
|
done chan struct{}
|
|
connStatusChanged chan nats.Status
|
|
fetchNext chan *pullRequest
|
|
consumeOpts *consumeOpts
|
|
delivered int
|
|
}
|
|
|
|
pendingMsgs struct {
|
|
msgCount int
|
|
byteCount int
|
|
}
|
|
|
|
MessageBatch interface {
|
|
Messages() <-chan Msg
|
|
Error() error
|
|
}
|
|
|
|
fetchResult struct {
|
|
msgs chan Msg
|
|
err error
|
|
done bool
|
|
sseq uint64
|
|
}
|
|
|
|
FetchOpt func(*pullRequest) error
|
|
|
|
hbMonitor struct {
|
|
timer *time.Timer
|
|
sync.Mutex
|
|
}
|
|
)
|
|
|
|
// DefaultMaxMessages is the batch size used by Consume and Messages when
// neither a message limit nor a byte limit is configured.
const DefaultMaxMessages = 500

// DefaultExpires is the default expiry of a pull request.
const DefaultExpires = 30 * time.Second

// unset marks an option that has not been set explicitly.
const unset = -1
|
|
|
|
// min returns the smaller of x and y.
func min(x, y int) int {
	if y <= x {
		return y
	}
	return x
}
|
|
|
|
// Consume can be used to continuously receive messages and handle them
|
|
// with the provided callback function. Consume cannot be used concurrently
|
|
// when using ordered consumer.
|
|
//
|
|
// See [Consumer.Consume] for more details.
|
|
func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (ConsumeContext, error) {
|
|
if handler == nil {
|
|
return nil, ErrHandlerRequired
|
|
}
|
|
consumeOpts, err := parseConsumeOpts(false, opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err)
|
|
}
|
|
p.Lock()
|
|
|
|
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
|
|
|
|
consumeID := nuid.Next()
|
|
sub := &pullSubscription{
|
|
id: consumeID,
|
|
consumer: p,
|
|
errs: make(chan error, 1),
|
|
done: make(chan struct{}, 1),
|
|
fetchNext: make(chan *pullRequest, 1),
|
|
consumeOpts: consumeOpts,
|
|
}
|
|
sub.connStatusChanged = p.jetStream.conn.StatusChanged(nats.CONNECTED, nats.RECONNECTING)
|
|
|
|
sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat)
|
|
|
|
p.subs.Store(sub.id, sub)
|
|
p.Unlock()
|
|
|
|
internalHandler := func(msg *nats.Msg) {
|
|
if sub.hbMonitor != nil {
|
|
sub.hbMonitor.Stop()
|
|
}
|
|
userMsg, msgErr := checkMsg(msg)
|
|
if !userMsg && msgErr == nil {
|
|
if sub.hbMonitor != nil {
|
|
sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat)
|
|
}
|
|
return
|
|
}
|
|
defer func() {
|
|
sub.Lock()
|
|
sub.checkPending()
|
|
if sub.hbMonitor != nil {
|
|
sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat)
|
|
}
|
|
sub.Unlock()
|
|
}()
|
|
if !userMsg {
|
|
// heartbeat message
|
|
if msgErr == nil {
|
|
return
|
|
}
|
|
|
|
sub.Lock()
|
|
err := sub.handleStatusMsg(msg, msgErr)
|
|
sub.Unlock()
|
|
|
|
if err != nil {
|
|
if sub.closed.Load() == 1 {
|
|
return
|
|
}
|
|
if sub.consumeOpts.ErrHandler != nil {
|
|
sub.consumeOpts.ErrHandler(sub, err)
|
|
}
|
|
sub.Stop()
|
|
}
|
|
return
|
|
}
|
|
handler(p.jetStream.toJSMsg(msg))
|
|
sub.Lock()
|
|
sub.decrementPendingMsgs(msg)
|
|
sub.incrementDeliveredMsgs()
|
|
sub.Unlock()
|
|
|
|
if sub.consumeOpts.StopAfter > 0 && sub.consumeOpts.StopAfter == sub.delivered {
|
|
sub.Stop()
|
|
}
|
|
}
|
|
inbox := p.jetStream.conn.NewInbox()
|
|
sub.subscription, err = p.jetStream.conn.Subscribe(inbox, internalHandler)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sub.subscription.SetClosedHandler(func(sid string) func(string) {
|
|
return func(subject string) {
|
|
p.subs.Delete(sid)
|
|
sub.draining.CompareAndSwap(1, 0)
|
|
}
|
|
}(sub.id))
|
|
|
|
sub.Lock()
|
|
// initial pull
|
|
sub.resetPendingMsgs()
|
|
batchSize := sub.consumeOpts.MaxMessages
|
|
if sub.consumeOpts.StopAfter > 0 {
|
|
batchSize = min(batchSize, sub.consumeOpts.StopAfter-sub.delivered)
|
|
}
|
|
if err := sub.pull(&pullRequest{
|
|
Expires: consumeOpts.Expires,
|
|
Batch: batchSize,
|
|
MaxBytes: consumeOpts.MaxBytes,
|
|
Heartbeat: consumeOpts.Heartbeat,
|
|
}, subject); err != nil {
|
|
sub.errs <- err
|
|
}
|
|
sub.Unlock()
|
|
|
|
go func() {
|
|
isConnected := true
|
|
for {
|
|
if sub.closed.Load() == 1 {
|
|
return
|
|
}
|
|
select {
|
|
case status, ok := <-sub.connStatusChanged:
|
|
if !ok {
|
|
continue
|
|
}
|
|
if status == nats.RECONNECTING {
|
|
if sub.hbMonitor != nil {
|
|
sub.hbMonitor.Stop()
|
|
}
|
|
isConnected = false
|
|
}
|
|
if status == nats.CONNECTED {
|
|
sub.Lock()
|
|
if !isConnected {
|
|
isConnected = true
|
|
if sub.consumeOpts.notifyOnReconnect {
|
|
sub.errs <- errConnected
|
|
}
|
|
|
|
sub.fetchNext <- &pullRequest{
|
|
Expires: sub.consumeOpts.Expires,
|
|
Batch: sub.consumeOpts.MaxMessages,
|
|
MaxBytes: sub.consumeOpts.MaxBytes,
|
|
Heartbeat: sub.consumeOpts.Heartbeat,
|
|
}
|
|
if sub.hbMonitor != nil {
|
|
sub.hbMonitor.Reset(2 * sub.consumeOpts.Heartbeat)
|
|
}
|
|
sub.resetPendingMsgs()
|
|
}
|
|
sub.Unlock()
|
|
}
|
|
case err := <-sub.errs:
|
|
sub.Lock()
|
|
if sub.consumeOpts.ErrHandler != nil {
|
|
sub.consumeOpts.ErrHandler(sub, err)
|
|
}
|
|
if errors.Is(err, ErrNoHeartbeat) {
|
|
batchSize := sub.consumeOpts.MaxMessages
|
|
if sub.consumeOpts.StopAfter > 0 {
|
|
batchSize = min(batchSize, sub.consumeOpts.StopAfter-sub.delivered)
|
|
}
|
|
sub.fetchNext <- &pullRequest{
|
|
Expires: sub.consumeOpts.Expires,
|
|
Batch: batchSize,
|
|
MaxBytes: sub.consumeOpts.MaxBytes,
|
|
Heartbeat: sub.consumeOpts.Heartbeat,
|
|
}
|
|
if sub.hbMonitor != nil {
|
|
sub.hbMonitor.Reset(2 * sub.consumeOpts.Heartbeat)
|
|
}
|
|
sub.resetPendingMsgs()
|
|
}
|
|
sub.Unlock()
|
|
case <-sub.done:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
go sub.pullMessages(subject)
|
|
|
|
return sub, nil
|
|
}
|
|
|
|
// resetPendingMsgs resets pending message count and byte count
|
|
// to the values set in consumeOpts
|
|
// lock should be held before calling this method
|
|
func (s *pullSubscription) resetPendingMsgs() {
|
|
s.pending.msgCount = s.consumeOpts.MaxMessages
|
|
s.pending.byteCount = s.consumeOpts.MaxBytes
|
|
}
|
|
|
|
// decrementPendingMsgs decrements pending message count and byte count
|
|
// lock should be held before calling this method
|
|
func (s *pullSubscription) decrementPendingMsgs(msg *nats.Msg) {
|
|
s.pending.msgCount--
|
|
if s.consumeOpts.MaxBytes != 0 {
|
|
s.pending.byteCount -= msg.Size()
|
|
}
|
|
}
|
|
|
|
// incrementDeliveredMsgs increments delivered message count
|
|
// lock should be held before calling this method
|
|
func (s *pullSubscription) incrementDeliveredMsgs() {
|
|
s.delivered++
|
|
}
|
|
|
|
// checkPending verifies whether there are enough messages in
|
|
// the buffer to trigger a new pull request.
|
|
// lock should be held before calling this method
|
|
func (s *pullSubscription) checkPending() {
|
|
if (s.pending.msgCount < s.consumeOpts.ThresholdMessages ||
|
|
(s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0)) &&
|
|
s.fetchInProgress.Load() == 0 {
|
|
|
|
var batchSize, maxBytes int
|
|
if s.consumeOpts.MaxBytes == 0 {
|
|
// if using messages, calculate appropriate batch size
|
|
batchSize = s.consumeOpts.MaxMessages - s.pending.msgCount
|
|
} else {
|
|
// if using bytes, use the max value
|
|
batchSize = s.consumeOpts.MaxMessages
|
|
maxBytes = s.consumeOpts.MaxBytes - s.pending.byteCount
|
|
}
|
|
if s.consumeOpts.StopAfter > 0 {
|
|
batchSize = min(batchSize, s.consumeOpts.StopAfter-s.delivered-s.pending.msgCount)
|
|
}
|
|
if batchSize > 0 {
|
|
s.fetchNext <- &pullRequest{
|
|
Expires: s.consumeOpts.Expires,
|
|
Batch: batchSize,
|
|
MaxBytes: maxBytes,
|
|
Heartbeat: s.consumeOpts.Heartbeat,
|
|
}
|
|
|
|
s.pending.msgCount = s.consumeOpts.MaxMessages
|
|
s.pending.byteCount = s.consumeOpts.MaxBytes
|
|
}
|
|
}
|
|
}
|
|
|
|
// Messages returns MessagesContext, allowing continuously iterating
|
|
// over messages on a stream. Messages cannot be used concurrently
|
|
// when using ordered consumer.
|
|
//
|
|
// See [Consumer.Messages] for more details.
|
|
func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error) {
|
|
consumeOpts, err := parseMessagesOpts(false, opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err)
|
|
}
|
|
|
|
p.Lock()
|
|
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
|
|
|
|
msgs := make(chan *nats.Msg, consumeOpts.MaxMessages)
|
|
|
|
consumeID := nuid.Next()
|
|
sub := &pullSubscription{
|
|
id: consumeID,
|
|
consumer: p,
|
|
done: make(chan struct{}, 1),
|
|
msgs: msgs,
|
|
errs: make(chan error, 1),
|
|
fetchNext: make(chan *pullRequest, 1),
|
|
consumeOpts: consumeOpts,
|
|
}
|
|
sub.connStatusChanged = p.jetStream.conn.StatusChanged(nats.CONNECTED, nats.RECONNECTING)
|
|
inbox := p.jetStream.conn.NewInbox()
|
|
sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs)
|
|
if err != nil {
|
|
p.Unlock()
|
|
return nil, err
|
|
}
|
|
sub.subscription.SetClosedHandler(func(sid string) func(string) {
|
|
return func(subject string) {
|
|
if sub.draining.Load() != 1 {
|
|
// if we're not draining, subscription can be closed as soon
|
|
// as closed handler is called
|
|
// otherwise, we need to wait until all messages are drained
|
|
// in Next
|
|
p.subs.Delete(sid)
|
|
}
|
|
close(msgs)
|
|
}
|
|
}(sub.id))
|
|
|
|
p.subs.Store(sub.id, sub)
|
|
p.Unlock()
|
|
|
|
go sub.pullMessages(subject)
|
|
|
|
go func() {
|
|
for {
|
|
select {
|
|
case status, ok := <-sub.connStatusChanged:
|
|
if !ok {
|
|
return
|
|
}
|
|
if status == nats.CONNECTED {
|
|
sub.errs <- errConnected
|
|
}
|
|
if status == nats.RECONNECTING {
|
|
sub.errs <- errDisconnected
|
|
}
|
|
case <-sub.done:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return sub, nil
|
|
}
|
|
|
|
// errConnected is sent on a subscription's errs channel when the
// connection reports the CONNECTED status.
var errConnected = errors.New("connected")

// errDisconnected is sent on a subscription's errs channel when the
// connection reports the RECONNECTING status.
var errDisconnected = errors.New("disconnected")
|
|
|
|
// Next retrieves next message on a stream. It will block until the next
|
|
// message is available. If the context is canceled, Next will return
|
|
// ErrMsgIteratorClosed error.
|
|
func (s *pullSubscription) Next() (Msg, error) {
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
drainMode := s.draining.Load() == 1
|
|
closed := s.closed.Load() == 1
|
|
if closed && !drainMode {
|
|
return nil, ErrMsgIteratorClosed
|
|
}
|
|
hbMonitor := s.scheduleHeartbeatCheck(2 * s.consumeOpts.Heartbeat)
|
|
defer func() {
|
|
if hbMonitor != nil {
|
|
hbMonitor.Stop()
|
|
}
|
|
}()
|
|
|
|
isConnected := true
|
|
if s.consumeOpts.StopAfter > 0 && s.delivered >= s.consumeOpts.StopAfter {
|
|
s.Stop()
|
|
return nil, ErrMsgIteratorClosed
|
|
}
|
|
|
|
for {
|
|
s.checkPending()
|
|
select {
|
|
case msg, ok := <-s.msgs:
|
|
if !ok {
|
|
// if msgs channel is closed, it means that subscription was either drained or stopped
|
|
s.consumer.subs.Delete(s.id)
|
|
s.draining.CompareAndSwap(1, 0)
|
|
return nil, ErrMsgIteratorClosed
|
|
}
|
|
if hbMonitor != nil {
|
|
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
|
|
}
|
|
userMsg, msgErr := checkMsg(msg)
|
|
if !userMsg {
|
|
// heartbeat message
|
|
if msgErr == nil {
|
|
continue
|
|
}
|
|
if err := s.handleStatusMsg(msg, msgErr); err != nil {
|
|
s.Stop()
|
|
return nil, err
|
|
}
|
|
continue
|
|
}
|
|
s.decrementPendingMsgs(msg)
|
|
s.incrementDeliveredMsgs()
|
|
return s.consumer.jetStream.toJSMsg(msg), nil
|
|
case err := <-s.errs:
|
|
if errors.Is(err, ErrNoHeartbeat) {
|
|
s.pending.msgCount = 0
|
|
s.pending.byteCount = 0
|
|
if s.consumeOpts.ReportMissingHeartbeats {
|
|
return nil, err
|
|
}
|
|
if hbMonitor != nil {
|
|
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
|
|
}
|
|
}
|
|
if errors.Is(err, errConnected) {
|
|
if !isConnected {
|
|
isConnected = true
|
|
|
|
if s.consumeOpts.notifyOnReconnect {
|
|
return nil, errConnected
|
|
}
|
|
s.pending.msgCount = 0
|
|
s.pending.byteCount = 0
|
|
if hbMonitor != nil {
|
|
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
|
|
}
|
|
}
|
|
}
|
|
if errors.Is(err, errDisconnected) {
|
|
if hbMonitor != nil {
|
|
hbMonitor.Stop()
|
|
}
|
|
isConnected = false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error {
|
|
if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) {
|
|
if errors.Is(msgErr, ErrConsumerDeleted) || errors.Is(msgErr, ErrBadRequest) {
|
|
return msgErr
|
|
}
|
|
if s.consumeOpts.ErrHandler != nil {
|
|
s.consumeOpts.ErrHandler(s, msgErr)
|
|
}
|
|
if errors.Is(msgErr, ErrConsumerLeadershipChanged) {
|
|
s.pending.msgCount = 0
|
|
s.pending.byteCount = 0
|
|
}
|
|
return nil
|
|
}
|
|
msgsLeft, bytesLeft, err := parsePending(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.pending.msgCount -= msgsLeft
|
|
if s.pending.msgCount < 0 {
|
|
s.pending.msgCount = 0
|
|
}
|
|
if s.consumeOpts.MaxBytes > 0 {
|
|
s.pending.byteCount -= bytesLeft
|
|
if s.pending.byteCount < 0 {
|
|
s.pending.byteCount = 0
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hb *hbMonitor) Stop() {
|
|
hb.Mutex.Lock()
|
|
hb.timer.Stop()
|
|
hb.Mutex.Unlock()
|
|
}
|
|
|
|
func (hb *hbMonitor) Reset(dur time.Duration) {
|
|
hb.Mutex.Lock()
|
|
hb.timer.Reset(dur)
|
|
hb.Mutex.Unlock()
|
|
}
|
|
|
|
// Stop unsubscribes from the stream and cancels subscription. Calling
|
|
// Next after calling Stop will return ErrMsgIteratorClosed error.
|
|
// All messages that are already in the buffer are discarded.
|
|
func (s *pullSubscription) Stop() {
|
|
if !s.closed.CompareAndSwap(0, 1) {
|
|
return
|
|
}
|
|
close(s.done)
|
|
if s.consumeOpts.stopAfterMsgsLeft != nil {
|
|
if s.delivered >= s.consumeOpts.StopAfter {
|
|
close(s.consumeOpts.stopAfterMsgsLeft)
|
|
} else {
|
|
s.consumeOpts.stopAfterMsgsLeft <- s.consumeOpts.StopAfter - s.delivered
|
|
}
|
|
}
|
|
}
|
|
|
|
// Drain unsubscribes from the stream and cancels subscription. All
|
|
// messages that are already in the buffer will be available on
|
|
// subsequent calls to Next. After the buffer is drained, Next will
|
|
// return ErrMsgIteratorClosed error.
|
|
func (s *pullSubscription) Drain() {
|
|
if !s.closed.CompareAndSwap(0, 1) {
|
|
return
|
|
}
|
|
s.draining.Store(1)
|
|
close(s.done)
|
|
if s.consumeOpts.stopAfterMsgsLeft != nil {
|
|
if s.delivered >= s.consumeOpts.StopAfter {
|
|
close(s.consumeOpts.stopAfterMsgsLeft)
|
|
} else {
|
|
s.consumeOpts.stopAfterMsgsLeft <- s.consumeOpts.StopAfter - s.delivered
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fetch sends a single request to retrieve given number of messages.
|
|
// It will wait up to provided expiry time if not all messages are available.
|
|
func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) {
|
|
req := &pullRequest{
|
|
Batch: batch,
|
|
Expires: DefaultExpires,
|
|
Heartbeat: unset,
|
|
}
|
|
for _, opt := range opts {
|
|
if err := opt(req); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
// if heartbeat was not explicitly set, set it to 5 seconds for longer pulls
|
|
// and disable it for shorter pulls
|
|
if req.Heartbeat == unset {
|
|
if req.Expires >= 10*time.Second {
|
|
req.Heartbeat = 5 * time.Second
|
|
} else {
|
|
req.Heartbeat = 0
|
|
}
|
|
}
|
|
if req.Expires < 2*req.Heartbeat {
|
|
return nil, fmt.Errorf("%w: expiry time should be at least 2 times the heartbeat", ErrInvalidOption)
|
|
}
|
|
|
|
return p.fetch(req)
|
|
}
|
|
|
|
// FetchBytes is used to retrieve up to a provided bytes from the stream.
|
|
func (p *pullConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) {
|
|
req := &pullRequest{
|
|
Batch: 1000000,
|
|
MaxBytes: maxBytes,
|
|
Expires: DefaultExpires,
|
|
Heartbeat: unset,
|
|
}
|
|
for _, opt := range opts {
|
|
if err := opt(req); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
// if heartbeat was not explicitly set, set it to 5 seconds for longer pulls
|
|
// and disable it for shorter pulls
|
|
if req.Heartbeat == unset {
|
|
if req.Expires >= 10*time.Second {
|
|
req.Heartbeat = 5 * time.Second
|
|
} else {
|
|
req.Heartbeat = 0
|
|
}
|
|
}
|
|
if req.Expires < 2*req.Heartbeat {
|
|
return nil, fmt.Errorf("%w: expiry time should be at least 2 times the heartbeat", ErrInvalidOption)
|
|
}
|
|
|
|
return p.fetch(req)
|
|
}
|
|
|
|
// FetchNoWait sends a single request to retrieve given number of messages.
|
|
// FetchNoWait will only return messages that are available at the time of the
|
|
// request. It will not wait for more messages to arrive.
|
|
func (p *pullConsumer) FetchNoWait(batch int) (MessageBatch, error) {
|
|
req := &pullRequest{
|
|
Batch: batch,
|
|
NoWait: true,
|
|
}
|
|
|
|
return p.fetch(req)
|
|
}
|
|
|
|
func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
|
|
res := &fetchResult{
|
|
msgs: make(chan Msg, req.Batch),
|
|
}
|
|
msgs := make(chan *nats.Msg, 2*req.Batch)
|
|
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
|
|
|
|
sub := &pullSubscription{
|
|
consumer: p,
|
|
done: make(chan struct{}, 1),
|
|
msgs: msgs,
|
|
errs: make(chan error, 1),
|
|
}
|
|
inbox := p.jetStream.conn.NewInbox()
|
|
var err error
|
|
sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := sub.pull(req, subject); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var receivedMsgs, receivedBytes int
|
|
hbTimer := sub.scheduleHeartbeatCheck(req.Heartbeat)
|
|
go func(res *fetchResult) {
|
|
defer sub.subscription.Unsubscribe()
|
|
defer close(res.msgs)
|
|
for {
|
|
select {
|
|
case msg := <-msgs:
|
|
p.Lock()
|
|
if hbTimer != nil {
|
|
hbTimer.Reset(2 * req.Heartbeat)
|
|
}
|
|
userMsg, err := checkMsg(msg)
|
|
if err != nil {
|
|
errNotTimeoutOrNoMsgs := !errors.Is(err, nats.ErrTimeout) && !errors.Is(err, ErrNoMessages)
|
|
if errNotTimeoutOrNoMsgs && !errors.Is(err, ErrMaxBytesExceeded) {
|
|
res.err = err
|
|
}
|
|
res.done = true
|
|
p.Unlock()
|
|
return
|
|
}
|
|
if !userMsg {
|
|
p.Unlock()
|
|
continue
|
|
}
|
|
res.msgs <- p.jetStream.toJSMsg(msg)
|
|
meta, err := msg.Metadata()
|
|
if err != nil {
|
|
res.err = fmt.Errorf("parsing message metadata: %s", err)
|
|
}
|
|
res.sseq = meta.Sequence.Stream
|
|
receivedMsgs++
|
|
if req.MaxBytes != 0 {
|
|
receivedBytes += msg.Size()
|
|
}
|
|
if receivedMsgs == req.Batch || (req.MaxBytes != 0 && receivedBytes >= req.MaxBytes) {
|
|
res.done = true
|
|
p.Unlock()
|
|
return
|
|
}
|
|
p.Unlock()
|
|
case err := <-sub.errs:
|
|
res.err = err
|
|
res.done = true
|
|
return
|
|
case <-time.After(req.Expires + 1*time.Second):
|
|
res.done = true
|
|
return
|
|
}
|
|
}
|
|
}(res)
|
|
return res, nil
|
|
}
|
|
|
|
func (fr *fetchResult) Messages() <-chan Msg {
|
|
return fr.msgs
|
|
}
|
|
|
|
func (fr *fetchResult) Error() error {
|
|
return fr.err
|
|
}
|
|
|
|
// Next is used to retrieve the next message from the stream. This
|
|
// method will block until the message is retrieved or timeout is
|
|
// reached.
|
|
func (p *pullConsumer) Next(opts ...FetchOpt) (Msg, error) {
|
|
res, err := p.Fetch(1, opts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
msg := <-res.Messages()
|
|
if msg != nil {
|
|
return msg, nil
|
|
}
|
|
if res.Error() == nil {
|
|
return nil, nats.ErrTimeout
|
|
}
|
|
return nil, res.Error()
|
|
}
|
|
|
|
func (s *pullSubscription) pullMessages(subject string) {
|
|
for {
|
|
select {
|
|
case req := <-s.fetchNext:
|
|
s.fetchInProgress.Store(1)
|
|
|
|
if err := s.pull(req, subject); err != nil {
|
|
if errors.Is(err, ErrMsgIteratorClosed) {
|
|
s.cleanup()
|
|
return
|
|
}
|
|
s.errs <- err
|
|
}
|
|
s.fetchInProgress.Store(0)
|
|
case <-s.done:
|
|
s.cleanup()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor {
|
|
if dur == 0 {
|
|
return nil
|
|
}
|
|
return &hbMonitor{
|
|
timer: time.AfterFunc(2*dur, func() {
|
|
s.errs <- ErrNoHeartbeat
|
|
}),
|
|
}
|
|
}
|
|
|
|
func (s *pullSubscription) cleanup() {
|
|
// For now this function does not need to hold the lock.
|
|
// Holding the lock here might cause a deadlock if Next()
|
|
// is already holding the lock and waiting.
|
|
// The fields that are read (subscription, hbMonitor)
|
|
// are read only (Only written on creation of pullSubscription).
|
|
if s.subscription == nil || !s.subscription.IsValid() {
|
|
return
|
|
}
|
|
if s.hbMonitor != nil {
|
|
s.hbMonitor.Stop()
|
|
}
|
|
drainMode := s.draining.Load() == 1
|
|
if drainMode {
|
|
s.subscription.Drain()
|
|
} else {
|
|
s.subscription.Unsubscribe()
|
|
}
|
|
s.closed.Store(1)
|
|
}
|
|
|
|
// pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription].
|
|
// Messages will be fetched up to given batch_size or until there are no more messages or timeout is returned
|
|
func (s *pullSubscription) pull(req *pullRequest, subject string) error {
|
|
s.consumer.Lock()
|
|
defer s.consumer.Unlock()
|
|
if s.closed.Load() == 1 {
|
|
return ErrMsgIteratorClosed
|
|
}
|
|
if req.Batch < 1 {
|
|
return fmt.Errorf("%w: batch size must be at least 1", nats.ErrInvalidArg)
|
|
}
|
|
reqJSON, err := json.Marshal(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
reply := s.subscription.Subject
|
|
if err := s.consumer.jetStream.conn.PublishRequest(subject, reply, reqJSON); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func parseConsumeOpts(ordered bool, opts ...PullConsumeOpt) (*consumeOpts, error) {
|
|
consumeOpts := &consumeOpts{
|
|
MaxMessages: unset,
|
|
MaxBytes: unset,
|
|
Expires: DefaultExpires,
|
|
Heartbeat: unset,
|
|
ReportMissingHeartbeats: true,
|
|
StopAfter: unset,
|
|
}
|
|
for _, opt := range opts {
|
|
if err := opt.configureConsume(consumeOpts); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := consumeOpts.setDefaults(ordered); err != nil {
|
|
return nil, err
|
|
}
|
|
return consumeOpts, nil
|
|
}
|
|
|
|
func parseMessagesOpts(ordered bool, opts ...PullMessagesOpt) (*consumeOpts, error) {
|
|
consumeOpts := &consumeOpts{
|
|
MaxMessages: unset,
|
|
MaxBytes: unset,
|
|
Expires: DefaultExpires,
|
|
Heartbeat: unset,
|
|
ReportMissingHeartbeats: true,
|
|
StopAfter: unset,
|
|
}
|
|
for _, opt := range opts {
|
|
if err := opt.configureMessages(consumeOpts); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := consumeOpts.setDefaults(ordered); err != nil {
|
|
return nil, err
|
|
}
|
|
return consumeOpts, nil
|
|
}
|
|
|
|
func (consumeOpts *consumeOpts) setDefaults(ordered bool) error {
|
|
if consumeOpts.MaxBytes != unset && consumeOpts.MaxMessages != unset {
|
|
return fmt.Errorf("only one of MaxMessages and MaxBytes can be specified")
|
|
}
|
|
if consumeOpts.MaxBytes != unset {
|
|
// when max_bytes is used, set batch size to a very large number
|
|
consumeOpts.MaxMessages = 1000000
|
|
} else if consumeOpts.MaxMessages != unset {
|
|
consumeOpts.MaxBytes = 0
|
|
} else {
|
|
if consumeOpts.MaxBytes == unset {
|
|
consumeOpts.MaxBytes = 0
|
|
}
|
|
if consumeOpts.MaxMessages == unset {
|
|
consumeOpts.MaxMessages = DefaultMaxMessages
|
|
}
|
|
}
|
|
|
|
if consumeOpts.ThresholdMessages == 0 {
|
|
consumeOpts.ThresholdMessages = int(math.Ceil(float64(consumeOpts.MaxMessages) / 2))
|
|
}
|
|
if consumeOpts.ThresholdBytes == 0 {
|
|
consumeOpts.ThresholdBytes = int(math.Ceil(float64(consumeOpts.MaxBytes) / 2))
|
|
}
|
|
if consumeOpts.Heartbeat == unset {
|
|
if ordered {
|
|
consumeOpts.Heartbeat = 5 * time.Second
|
|
if consumeOpts.Expires < 10*time.Second {
|
|
consumeOpts.Heartbeat = consumeOpts.Expires / 2
|
|
}
|
|
} else {
|
|
consumeOpts.Heartbeat = consumeOpts.Expires / 2
|
|
if consumeOpts.Heartbeat > 30*time.Second {
|
|
consumeOpts.Heartbeat = 30 * time.Second
|
|
}
|
|
}
|
|
}
|
|
if consumeOpts.Heartbeat > consumeOpts.Expires/2 {
|
|
return fmt.Errorf("the value of Heartbeat must be less than 50%% of expiry")
|
|
}
|
|
return nil
|
|
}
|