// Copyright 2020-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package jetstream

import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"sync"
"sync/atomic"
"time"
"github.com/nats-io/nats.go"
"github.com/nats-io/nuid"
)
type (
// MessagesContext supports iterating over messages on a stream.
MessagesContext interface {
// Next retrieves the next message on a stream. It will block until the next message is available.
Next() (Msg, error)
// Stop closes the iterator and cancels the subscription.
Stop()
}
// ConsumeContext supports processing incoming messages from a stream in a callback.
ConsumeContext interface {
// Stop cancels the subscription and stops processing messages.
Stop()
}
// MessageHandler is a handler function used as a callback in [Consume].
MessageHandler func(msg Msg)
// PullConsumeOpt represents additional options used in [Consume] for pull consumers.
PullConsumeOpt interface {
configureConsume(*consumeOpts) error
}
// PullMessagesOpt represents additional options used in [Messages] for pull consumers.
PullMessagesOpt interface {
configureMessages(*consumeOpts) error
}
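// pullConsumer implements the [Consumer] interface, backed by a JetStream pull consumer.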
pullConsumer struct {
sync.Mutex
jetStream *jetStream
stream string
durable bool
name string
info *ConsumerInfo
subscriptions map[string]*pullSubscription
}
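// pullRequest is the JSON payload published to the consumer's next-message API subject to request a batch of messages.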
pullRequest struct {
Expires time.Duration `json:"expires,omitempty"`
Batch int `json:"batch,omitempty"`
MaxBytes int `json:"max_bytes,omitempty"`
NoWait bool `json:"no_wait,omitempty"`
Heartbeat time.Duration `json:"idle_heartbeat,omitempty"`
}
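// consumeOpts holds the effective configuration for [Consume] and [Messages], resolved from the provided options and defaults.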
consumeOpts struct {
Expires time.Duration
MaxMessages int
MaxBytes int
Heartbeat time.Duration
ErrHandler ConsumeErrHandlerFunc
ReportMissingHeartbeats bool
ThresholdMessages int
ThresholdBytes int
}
// ConsumeErrHandlerFunc is a callback invoked when an error occurs while consuming messages.
ConsumeErrHandlerFunc func(consumeCtx ConsumeContext, err error)
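// pullSubscription tracks the state of a single subscription created by [Consume], [Messages] or fetch:
// the inbox subscription, pending message and byte counts, and the heartbeat monitor.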
pullSubscription struct {
sync.Mutex
id string
consumer *pullConsumer
subscription *nats.Subscription
msgs chan *nats.Msg
errs chan error
pending pendingMsgs
hbMonitor *hbMonitor
fetchInProgress uint32
closed uint32
done chan struct{}
connected chan struct{}
disconnected chan struct{}
fetchNext chan *pullRequest
consumeOpts *consumeOpts
}
pendingMsgs struct {
msgCount int
byteCount int
}
// MessageBatch provides access to messages retrieved by [Fetch] and [FetchBytes].
MessageBatch interface {
// Messages returns a channel on which the fetched messages are delivered.
Messages() <-chan Msg
// Error returns an error encountered while fetching messages, if any.
Error() error
}
// fetchResult implements [MessageBatch].
fetchResult struct {
msgs chan Msg
err error
done bool
sseq uint64
}
// FetchOpt represents additional options used in [Fetch] and [FetchBytes].
FetchOpt func(*pullRequest) error
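// hbMonitor guards a timer used to detect missed idle heartbeats; when the
// timer fires, ErrNoHeartbeat is reported on the subscription's error channel.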
hbMonitor struct {
timer *time.Timer
sync.Mutex
}
)
const (
// DefaultMaxMessages is the default maximum number of messages buffered by [Consume] and [Messages].
DefaultMaxMessages = 500
// DefaultExpires is the default timeout for a single pull request.
DefaultExpires = 30 * time.Second
// DefaultHeartbeat is the default idle heartbeat set for longer pull requests in [Fetch] and [FetchBytes].
DefaultHeartbeat = 5 * time.Second
// unset marks an option value not provided by the caller.
unset = -1
)
// Consume returns a ConsumeContext, allowing for processing incoming messages from a stream in a given callback function.
//
// Available options:
// [ConsumeMaxMessages] - sets maximum number of messages stored in a buffer, default is set to 500
// [ConsumeMaxBytes] - sets maximum number of bytes stored in a buffer
// [ConsumeExpiry] - sets a timeout for individual batch request, default is set to 30 seconds
// [ConsumeHeartbeat] - sets an idle heartbeat setting for a pull request, defaults to half of the expiry time, capped at 30 seconds
// [ConsumeErrHandler] - sets custom consume error callback handler
// [ConsumeThresholdMessages] - sets the message count on which Consume will trigger new pull request to the server
// [ConsumeThresholdBytes] - sets the byte count on which Consume will trigger new pull request to the server
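//
// An illustrative usage sketch (assumes cons is a jetstream.Consumer obtained from a stream):
//
//	cc, err := cons.Consume(func(msg jetstream.Msg) {
//		fmt.Println("received:", string(msg.Data()))
//		msg.Ack()
//	})
//	if err != nil {
//		// handle error
//	}
//	defer cc.Stop()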
func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (ConsumeContext, error) {
if handler == nil {
return nil, ErrHandlerRequired
}
consumeOpts, err := parseConsumeOpts(opts...)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err)
}
p.Lock()
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
// for single consume, use empty string as id
// this is useful for ordered consumer, where only a single subscription is valid
var consumeID string
if len(p.subscriptions) > 0 {
consumeID = nuid.Next()
}
sub := &pullSubscription{
id: consumeID,
consumer: p,
errs: make(chan error, 1),
done: make(chan struct{}, 1),
fetchNext: make(chan *pullRequest, 1),
connected: make(chan struct{}),
disconnected: make(chan struct{}),
consumeOpts: consumeOpts,
}
p.jetStream.conn.RegisterStatusChangeListener(nats.CONNECTED, sub.connected)
p.jetStream.conn.RegisterStatusChangeListener(nats.DISCONNECTED, sub.disconnected)
p.jetStream.conn.RegisterStatusChangeListener(nats.RECONNECTING, sub.disconnected)
sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat)
p.subscriptions[sub.id] = sub
p.Unlock()
internalHandler := func(msg *nats.Msg) {
if sub.hbMonitor != nil {
sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat)
}
userMsg, msgErr := checkMsg(msg)
if !userMsg && msgErr == nil {
return
}
defer func() {
// Schedule another pull request once pending counts fall below the
// thresholds, but only when no other fetch is in progress.
if (sub.pending.msgCount < consumeOpts.ThresholdMessages ||
(sub.pending.byteCount < consumeOpts.ThresholdBytes && sub.consumeOpts.MaxBytes != 0)) &&
atomic.LoadUint32(&sub.fetchInProgress) == 0 {
sub.fetchNext <- &pullRequest{
Expires: sub.consumeOpts.Expires,
Batch: sub.consumeOpts.MaxMessages - sub.pending.msgCount,
MaxBytes: sub.consumeOpts.MaxBytes - sub.pending.byteCount,
Heartbeat: sub.consumeOpts.Heartbeat,
}
sub.resetPendingMsgs()
}
}()
if !userMsg {
// heartbeat message
if msgErr == nil {
return
}
if err := sub.handleStatusMsg(msg, msgErr); err != nil {
if atomic.LoadUint32(&sub.closed) == 1 {
return
}
if sub.consumeOpts.ErrHandler != nil {
sub.consumeOpts.ErrHandler(sub, err)
}
sub.Stop()
}
return
}
handler(p.jetStream.toJSMsg(msg))
sub.decrementPendingMsgs(msg)
}
inbox := nats.NewInbox()
sub.subscription, err = p.jetStream.conn.Subscribe(inbox, internalHandler)
if err != nil {
return nil, err
}
// initial pull
sub.resetPendingMsgs()
if err := sub.pull(&pullRequest{
Expires: consumeOpts.Expires,
Batch: consumeOpts.MaxMessages,
MaxBytes: consumeOpts.MaxBytes,
Heartbeat: consumeOpts.Heartbeat,
}, subject); err != nil {
sub.errs <- err
}
go func() {
isConnected := true
for {
if atomic.LoadUint32(&sub.closed) == 1 {
return
}
select {
case <-sub.disconnected:
if sub.hbMonitor != nil {
sub.hbMonitor.Stop()
}
isConnected = false
case <-sub.connected:
if !isConnected {
// try fetching consumer info several times to make sure consumer is available after reconnect
for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := p.Info(ctx)
cancel()
if err == nil {
break
}
// err is non-nil here; give up after the final attempt
if i == 4 {
sub.cleanupSubscriptionAndRestoreConnHandler()
if sub.consumeOpts.ErrHandler != nil {
sub.consumeOpts.ErrHandler(sub, err)
}
return
}
time.Sleep(5 * time.Second)
}
sub.fetchNext <- &pullRequest{
Expires: sub.consumeOpts.Expires,
Batch: sub.consumeOpts.MaxMessages,
MaxBytes: sub.consumeOpts.MaxBytes,
Heartbeat: sub.consumeOpts.Heartbeat,
}
sub.resetPendingMsgs()
isConnected = true
}
case err := <-sub.errs:
if sub.consumeOpts.ErrHandler != nil {
sub.consumeOpts.ErrHandler(sub, err)
}
if errors.Is(err, ErrNoHeartbeat) {
sub.fetchNext <- &pullRequest{
Expires: sub.consumeOpts.Expires,
Batch: sub.consumeOpts.MaxMessages,
MaxBytes: sub.consumeOpts.MaxBytes,
Heartbeat: sub.consumeOpts.Heartbeat,
}
sub.resetPendingMsgs()
}
}
}
}()
go sub.pullMessages(subject)
return sub, nil
}
func (s *pullSubscription) resetPendingMsgs() {
s.Lock()
defer s.Unlock()
s.pending.msgCount = s.consumeOpts.MaxMessages
s.pending.byteCount = s.consumeOpts.MaxBytes
}
func (s *pullSubscription) decrementPendingMsgs(msg *nats.Msg) {
s.Lock()
defer s.Unlock()
s.pending.msgCount--
if s.consumeOpts.MaxBytes != 0 {
s.pending.byteCount -= msgSize(msg)
}
}
// Messages returns a MessagesContext, allowing continuous iteration over messages on a stream.
//
// Available options:
// [ConsumeMaxMessages] - sets maximum number of messages stored in a buffer, default is set to 500
// [ConsumeMaxBytes] - sets maximum number of bytes stored in a buffer
// [ConsumeExpiry] - sets a timeout for individual batch request, default is set to 30 seconds
// [ConsumeHeartbeat] - sets an idle heartbeat setting for a pull request, defaults to half of the expiry time, capped at 30 seconds
// [ConsumeErrHandler] - sets custom consume error callback handler
// [ConsumeThresholdMessages] - sets the message count on which Messages will trigger new pull request to the server
// [ConsumeThresholdBytes] - sets the byte count on which Messages will trigger new pull request to the server
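//
// An illustrative usage sketch (assumes cons is a jetstream.Consumer obtained from a stream):
//
//	it, err := cons.Messages()
//	if err != nil {
//		// handle error
//	}
//	defer it.Stop()
//	for {
//		msg, err := it.Next()
//		if err != nil {
//			// ErrMsgIteratorClosed is returned after Stop
//			break
//		}
//		msg.Ack()
//	}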
func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error) {
consumeOpts, err := parseMessagesOpts(opts...)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err)
}
p.Lock()
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
msgs := make(chan *nats.Msg, consumeOpts.MaxMessages)
// for single consume, use empty string as id
// this is useful for ordered consumer, where only a single subscription is valid
var consumeID string
if len(p.subscriptions) > 0 {
consumeID = nuid.Next()
}
sub := &pullSubscription{
id: consumeID,
consumer: p,
done: make(chan struct{}, 1),
msgs: msgs,
errs: make(chan error, 1),
fetchNext: make(chan *pullRequest, 1),
connected: make(chan struct{}),
disconnected: make(chan struct{}),
consumeOpts: consumeOpts,
}
p.jetStream.conn.RegisterStatusChangeListener(nats.CONNECTED, sub.connected)
p.jetStream.conn.RegisterStatusChangeListener(nats.DISCONNECTED, sub.disconnected)
p.jetStream.conn.RegisterStatusChangeListener(nats.RECONNECTING, sub.disconnected)
inbox := nats.NewInbox()
sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs)
if err != nil {
p.Unlock()
return nil, err
}
go func() {
<-sub.done
sub.cleanupSubscriptionAndRestoreConnHandler()
}()
p.subscriptions[sub.id] = sub
p.Unlock()
go sub.pullMessages(subject)
return sub, nil
}
func (s *pullSubscription) Next() (Msg, error) {
s.Lock()
defer s.Unlock()
if atomic.LoadUint32(&s.closed) == 1 {
return nil, ErrMsgIteratorClosed
}
hbMonitor := s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat)
defer func() {
if hbMonitor != nil {
hbMonitor.Stop()
}
}()
isConnected := true
for {
// Request more messages once pending counts fall below the thresholds,
// but only when no other fetch is in progress.
if (s.pending.msgCount < s.consumeOpts.ThresholdMessages ||
(s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0)) &&
atomic.LoadUint32(&s.fetchInProgress) == 0 {
s.fetchNext <- &pullRequest{
Expires: s.consumeOpts.Expires,
Batch: s.consumeOpts.MaxMessages - s.pending.msgCount,
MaxBytes: s.consumeOpts.MaxBytes - s.pending.byteCount,
Heartbeat: s.consumeOpts.Heartbeat,
}
s.pending.msgCount = s.consumeOpts.MaxMessages
if s.consumeOpts.MaxBytes > 0 {
s.pending.byteCount = s.consumeOpts.MaxBytes
}
}
select {
case msg := <-s.msgs:
if hbMonitor != nil {
hbMonitor.Reset(2 * s.consumeOpts.Heartbeat)
}
userMsg, msgErr := checkMsg(msg)
if !userMsg {
// heartbeat message
if msgErr == nil {
continue
}
if err := s.handleStatusMsg(msg, msgErr); err != nil {
s.Stop()
return nil, err
}
continue
}
s.pending.msgCount--
if s.consumeOpts.MaxBytes > 0 {
s.pending.byteCount -= msgSize(msg)
}
return s.consumer.jetStream.toJSMsg(msg), nil
case <-s.disconnected:
if hbMonitor != nil {
hbMonitor.Stop()
}
isConnected = false
case <-s.connected:
if !isConnected {
// try fetching consumer info several times to make sure consumer is available after reconnect
for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := s.consumer.Info(ctx)
cancel()
if err == nil {
break
}
// err is non-nil here; give up after the final attempt
if i == 4 {
s.Stop()
return nil, err
}
time.Sleep(5 * time.Second)
}
s.pending.msgCount = 0
s.pending.byteCount = 0
hbMonitor = s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat)
}
case err := <-s.errs:
if errors.Is(err, ErrNoHeartbeat) {
s.pending.msgCount = 0
s.pending.byteCount = 0
if s.consumeOpts.ReportMissingHeartbeats {
return nil, err
}
}
}
}
}
func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error {
if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) {
if s.consumeOpts.ErrHandler != nil {
s.consumeOpts.ErrHandler(s, msgErr)
}
if errors.Is(msgErr, ErrConsumerDeleted) || errors.Is(msgErr, ErrBadRequest) {
return msgErr
}
if errors.Is(msgErr, ErrConsumerLeadershipChanged) {
s.pending.msgCount = 0
s.pending.byteCount = 0
}
return nil
}
msgsLeft, bytesLeft, err := parsePending(msg)
if err != nil {
if s.consumeOpts.ErrHandler != nil {
s.consumeOpts.ErrHandler(s, err)
}
}
s.pending.msgCount -= msgsLeft
if s.pending.msgCount < 0 {
s.pending.msgCount = 0
}
if s.consumeOpts.MaxBytes > 0 {
s.pending.byteCount -= bytesLeft
if s.pending.byteCount < 0 {
s.pending.byteCount = 0
}
}
return nil
}
func (hb *hbMonitor) Stop() {
hb.Mutex.Lock()
hb.timer.Stop()
hb.Mutex.Unlock()
}
func (hb *hbMonitor) Reset(dur time.Duration) {
hb.Mutex.Lock()
hb.timer.Reset(dur)
hb.Mutex.Unlock()
}
func (s *pullSubscription) Stop() {
// CompareAndSwap guarantees the done channel is closed only once,
// even if Stop is called concurrently or repeatedly.
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return
}
close(s.done)
}
// Fetch sends a single request to retrieve a given number of messages.
// It will wait up to the provided expiry time if not all messages are available.
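//
// An illustrative usage sketch (assumes cons is a jetstream.Consumer obtained from a stream):
//
//	batch, err := cons.Fetch(10)
//	if err != nil {
//		// handle error
//	}
//	for msg := range batch.Messages() {
//		msg.Ack()
//	}
//	if err := batch.Error(); err != nil {
//		// handle fetch error
//	}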
func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) {
req := &pullRequest{
Batch: batch,
Expires: DefaultExpires,
}
for _, opt := range opts {
if err := opt(req); err != nil {
return nil, err
}
}
// for longer pulls, set heartbeat value
if req.Expires >= 10*time.Second {
req.Heartbeat = DefaultHeartbeat
}
return p.fetch(req)
}
// FetchBytes is used to retrieve up to a provided number of bytes from the stream.
// This method will always send a single request and wait until the provided number of
// bytes is exceeded or the request times out.
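//
// An illustrative usage sketch (assumes cons is a jetstream.Consumer obtained from a stream):
//
//	batch, err := cons.FetchBytes(1024)
//	if err != nil {
//		// handle error
//	}
//	for msg := range batch.Messages() {
//		msg.Ack()
//	}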
func (p *pullConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) {
req := &pullRequest{
// batch is effectively unlimited; the fetch is bounded by MaxBytes
Batch: 1000000,
MaxBytes: maxBytes,
Expires: DefaultExpires,
}
for _, opt := range opts {
if err := opt(req); err != nil {
return nil, err
}
}
// for longer pulls, set heartbeat value
if req.Expires >= 10*time.Second {
req.Heartbeat = DefaultHeartbeat
}
return p.fetch(req)
}
// FetchNoWait sends a single request to retrieve a given number of messages.
// FetchNoWait will only return messages that are available at the time of the
// request; it will not wait for more messages to arrive.
func (p *pullConsumer) FetchNoWait(batch int) (MessageBatch, error) {
req := &pullRequest{
Batch: batch,
NoWait: true,
}
return p.fetch(req)
}
func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) {
res := &fetchResult{
msgs: make(chan Msg, req.Batch),
}
msgs := make(chan *nats.Msg, 2*req.Batch)
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
sub := &pullSubscription{
consumer: p,
done: make(chan struct{}, 1),
msgs: msgs,
errs: make(chan error, 1),
}
inbox := nats.NewInbox()
var err error
sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs)
if err != nil {
return nil, err
}
if err := sub.pull(req, subject); err != nil {
return nil, err
}
var receivedMsgs, receivedBytes int
hbTimer := sub.scheduleHeartbeatCheck(req.Heartbeat)
go func(res *fetchResult) {
defer sub.subscription.Unsubscribe()
defer close(res.msgs)
for {
if receivedMsgs == req.Batch || (req.MaxBytes != 0 && receivedBytes == req.MaxBytes) {
res.done = true
return
}
select {
case msg := <-msgs:
if hbTimer != nil {
hbTimer.Reset(2 * req.Heartbeat)
}
userMsg, err := checkMsg(msg)
if err != nil {
if !errors.Is(err, nats.ErrTimeout) && !errors.Is(err, ErrNoMessages) && !errors.Is(err, ErrMaxBytesExceeded) {
res.err = err
}
res.done = true
return
}
if !userMsg {
continue
}
res.msgs <- p.jetStream.toJSMsg(msg)
meta, err := msg.Metadata()
if err != nil {
// record the error but do not dereference nil metadata
res.err = fmt.Errorf("parsing message metadata: %s", err)
} else {
res.sseq = meta.Sequence.Stream
}
receivedMsgs++
if req.MaxBytes != 0 {
receivedBytes += msgSize(msg)
}
case <-time.After(req.Expires + 1*time.Second):
res.err = fmt.Errorf("fetch timed out")
res.done = true
return
}
}
}(res)
return res, nil
}
func (fr *fetchResult) Messages() <-chan Msg {
return fr.msgs
}
func (fr *fetchResult) Error() error {
return fr.err
}
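// Next is used to retrieve the next message from the stream. It will block
// until the message is retrieved or the request times out.
//
// An illustrative usage sketch (assumes cons is a jetstream.Consumer obtained from a stream):
//
//	msg, err := cons.Next()
//	if err != nil {
//		// handle error
//	}
//	msg.Ack()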
func (p *pullConsumer) Next(opts ...FetchOpt) (Msg, error) {
res, err := p.Fetch(1, opts...)
if err != nil {
return nil, err
}
msg := <-res.Messages()
if msg != nil {
return msg, nil
}
return nil, res.Error()
}
func (s *pullSubscription) pullMessages(subject string) {
for {
select {
case req := <-s.fetchNext:
atomic.StoreUint32(&s.fetchInProgress, 1)
if err := s.pull(req, subject); err != nil {
if errors.Is(err, ErrMsgIteratorClosed) {
s.cleanupSubscriptionAndRestoreConnHandler()
return
}
s.errs <- err
}
atomic.StoreUint32(&s.fetchInProgress, 0)
case <-s.done:
s.cleanupSubscriptionAndRestoreConnHandler()
return
}
}
}
func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor {
if dur == 0 {
return nil
}
return &hbMonitor{
timer: time.AfterFunc(2*dur, func() {
s.errs <- ErrNoHeartbeat
}),
}
}
func (s *pullSubscription) cleanupSubscriptionAndRestoreConnHandler() {
s.consumer.Lock()
defer s.consumer.Unlock()
if s.subscription == nil {
return
}
if s.hbMonitor != nil {
s.hbMonitor.Stop()
}
s.subscription.Unsubscribe()
close(s.connected)
close(s.disconnected)
s.subscription = nil
delete(s.consumer.subscriptions, s.id)
}
func msgSize(msg *nats.Msg) int {
if msg == nil {
return 0
}
return len(msg.Subject) + len(msg.Reply) + len(msg.Data)
}
// pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription].
// Messages will be fetched up to the given batch size, or until there are no more messages or the request times out.
func (s *pullSubscription) pull(req *pullRequest, subject string) error {
s.consumer.Lock()
defer s.consumer.Unlock()
if atomic.LoadUint32(&s.closed) == 1 {
return ErrMsgIteratorClosed
}
if req.Batch < 1 {
return fmt.Errorf("%w: batch size must be at least 1", nats.ErrInvalidArg)
}
reqJSON, err := json.Marshal(req)
if err != nil {
return err
}
reply := s.subscription.Subject
if err := s.consumer.jetStream.conn.PublishRequest(subject, reply, reqJSON); err != nil {
return err
}
return nil
}
func parseConsumeOpts(opts ...PullConsumeOpt) (*consumeOpts, error) {
consumeOpts := &consumeOpts{
MaxMessages: unset,
MaxBytes: unset,
Expires: DefaultExpires,
Heartbeat: unset,
ReportMissingHeartbeats: true,
}
for _, opt := range opts {
if err := opt.configureConsume(consumeOpts); err != nil {
return nil, err
}
}
if err := consumeOpts.setDefaults(); err != nil {
return nil, err
}
return consumeOpts, nil
}
func parseMessagesOpts(opts ...PullMessagesOpt) (*consumeOpts, error) {
consumeOpts := &consumeOpts{
MaxMessages: unset,
MaxBytes: unset,
Expires: DefaultExpires,
Heartbeat: unset,
ReportMissingHeartbeats: true,
}
for _, opt := range opts {
if err := opt.configureMessages(consumeOpts); err != nil {
return nil, err
}
}
if err := consumeOpts.setDefaults(); err != nil {
return nil, err
}
return consumeOpts, nil
}
func (consumeOpts *consumeOpts) setDefaults() error {
if consumeOpts.MaxBytes != unset && consumeOpts.MaxMessages != unset {
return fmt.Errorf("only one of MaxMessages and MaxBytes can be specified")
}
if consumeOpts.MaxBytes != unset {
// when max_bytes is used, set batch size to a very large number
consumeOpts.MaxMessages = 1000000
} else if consumeOpts.MaxMessages != unset {
consumeOpts.MaxBytes = 0
} else {
consumeOpts.MaxBytes = 0
consumeOpts.MaxMessages = DefaultMaxMessages
}
if consumeOpts.ThresholdMessages == 0 {
consumeOpts.ThresholdMessages = int(math.Ceil(float64(consumeOpts.MaxMessages) / 2))
}
if consumeOpts.ThresholdBytes == 0 {
consumeOpts.ThresholdBytes = int(math.Ceil(float64(consumeOpts.MaxBytes) / 2))
}
if consumeOpts.Heartbeat == unset {
consumeOpts.Heartbeat = consumeOpts.Expires / 2
if consumeOpts.Heartbeat > 30*time.Second {
consumeOpts.Heartbeat = 30 * time.Second
}
}
if consumeOpts.Heartbeat > consumeOpts.Expires/2 {
return fmt.Errorf("the value of Heartbeat must be less than 50%% of expiry")
}
return nil
}