Add publish options to BatchPublisher.Add

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
Piotr Piotrowski
2025-09-12 09:04:57 +02:00
parent aa93df6c15
commit 3c7f272bf7
4 changed files with 373 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ import (
"fmt"
"strconv"
"sync"
"time"
"github.com/nats-io/nats.go"
"github.com/nats-io/nuid"
@@ -30,19 +31,19 @@ type (
// with the final message which includes a commit header.
BatchPublisher interface {
// Add publishes a message to the batch with the given subject and data.
Add(subject string, data []byte) error
Add(subject string, data []byte, opts ...BatchMsgOpt) error
// AddMsg publishes a message to the batch.
AddMsg(msg *nats.Msg) error
AddMsg(msg *nats.Msg, opts ...BatchMsgOpt) error
// Commit publishes the final message with the given subject and data,
// and commits the batch. Returns a BatchAck containing the acknowledgment
// from the server.
Commit(ctx context.Context, subject string, data []byte) (*BatchAck, error)
Commit(ctx context.Context, subject string, data []byte, opts ...BatchMsgOpt) (*BatchAck, error)
// CommitMsg publishes the final message and commits the batch.
// Returns a BatchAck containing the acknowledgment from the server.
CommitMsg(ctx context.Context, msg *nats.Msg) (*BatchAck, error)
CommitMsg(ctx context.Context, msg *nats.Msg, opts ...BatchMsgOpt) (*BatchAck, error)
// Discard cancels the batch without committing.
// The server will abandon the batch after a timeout.
@@ -91,6 +92,17 @@ type (
BatchID string `json:"batch,omitempty"`
BatchSize int `json:"count,omitempty"`
}
// BatchMsgOpt is an option for configuring batch message publishing.
BatchMsgOpt func(*batchMsgOpts) error
batchMsgOpts struct {
ttl time.Duration
stream string
lastSeq *uint64
lastSubjectSeq *uint64
lastSubject string
}
)
// BatchPublisher creates a new batch publisher for publishing messages in batches.
@@ -102,12 +114,12 @@ func (js *jetStream) BatchPublisher() (BatchPublisher, error) {
}
// Add publishes a message to the batch with the given subject and data.
func (b *batchPublisher) Add(subject string, data []byte) error {
return b.AddMsg(&nats.Msg{Subject: subject, Data: data})
func (b *batchPublisher) Add(subject string, data []byte, opts ...BatchMsgOpt) error {
return b.AddMsg(&nats.Msg{Subject: subject, Data: data}, opts...)
}
// AddMsg publishes a message to the batch.
func (b *batchPublisher) AddMsg(msg *nats.Msg) error {
func (b *batchPublisher) AddMsg(msg *nats.Msg, opts ...BatchMsgOpt) error {
b.mu.Lock()
defer b.mu.Unlock()
@@ -119,21 +131,50 @@ func (b *batchPublisher) AddMsg(msg *nats.Msg) error {
msg.Header = nats.Header{}
}
// Process batch message options
o := batchMsgOpts{}
for _, opt := range opts {
if err := opt(&o); err != nil {
return err
}
}
// Validate ExpectLastSequence options can only be used on first message
if b.sequence > 0 && o.lastSeq != nil {
return ErrBatchExpectLastSequenceNotFirst
}
if o.ttl > 0 {
msg.Header.Set(MsgTTLHeader, o.ttl.String())
}
if o.stream != "" {
msg.Header.Set(ExpectedStreamHeader, o.stream)
}
if o.lastSubjectSeq != nil {
msg.Header.Set(ExpectedLastSubjSeqHeader, strconv.FormatUint(*o.lastSubjectSeq, 10))
}
if o.lastSubject != "" {
msg.Header.Set(ExpectedLastSubjSeqSubjHeader, o.lastSubject)
msg.Header.Set(ExpectedLastSubjSeqHeader, strconv.FormatUint(*o.lastSubjectSeq, 10))
}
if o.lastSeq != nil {
msg.Header.Set(ExpectedLastSeqHeader, strconv.FormatUint(*o.lastSeq, 10))
}
b.sequence++
msg.Header.Set(BatchIDHeader, b.batchID)
msg.Header.Set(BatchSeqHeader, strconv.FormatUint(uint64(b.sequence), 10))
// Publish immediately
return b.js.conn.PublishMsg(msg)
}
// Commit publishes the final message and commits the batch.
func (b *batchPublisher) Commit(ctx context.Context, subject string, data []byte) (*BatchAck, error) {
return b.CommitMsg(ctx, &nats.Msg{Subject: subject, Data: data})
func (b *batchPublisher) Commit(ctx context.Context, subject string, data []byte, opts ...BatchMsgOpt) (*BatchAck, error) {
return b.CommitMsg(ctx, &nats.Msg{Subject: subject, Data: data}, opts...)
}
// CommitMsg publishes the final message and commits the batch.
func (b *batchPublisher) CommitMsg(ctx context.Context, msg *nats.Msg) (*BatchAck, error) {
func (b *batchPublisher) CommitMsg(ctx context.Context, msg *nats.Msg, opts ...BatchMsgOpt) (*BatchAck, error) {
ctx, cancel := b.js.wrapContextWithoutDeadline(ctx)
if cancel != nil {
defer cancel()
@@ -144,18 +185,47 @@ func (b *batchPublisher) CommitMsg(ctx context.Context, msg *nats.Msg) (*BatchAc
if b.closed {
return nil, ErrBatchClosed
}
// Process batch message options and convert to PublishOpt
o := batchMsgOpts{}
for _, opt := range opts {
if err := opt(&o); err != nil {
return nil, err
}
}
// Validate ExpectLastSequence options can only be used on first message
if b.sequence > 0 && o.lastSeq != nil {
return nil, ErrBatchExpectLastSequenceNotFirst
}
// Convert batch options to publish options for commit
var pubOpts []PublishOpt
if o.ttl > 0 {
pubOpts = append(pubOpts, WithMsgTTL(o.ttl))
}
if o.stream != "" {
pubOpts = append(pubOpts, WithExpectStream(o.stream))
}
if o.lastSeq != nil {
pubOpts = append(pubOpts, WithExpectLastSequence(*o.lastSeq))
}
if o.lastSubject != "" && o.lastSubjectSeq != nil {
pubOpts = append(pubOpts, WithExpectLastSequenceForSubject(*o.lastSubjectSeq, o.lastSubject))
} else if o.lastSubjectSeq != nil {
pubOpts = append(pubOpts, WithExpectLastSequencePerSubject(*o.lastSubjectSeq))
}
b.sequence++
// Check for unsupported headers
if msg.Header == nil {
msg.Header = nats.Header{}
}
b.sequence++
msg.Header.Set(BatchIDHeader, b.batchID)
msg.Header.Set(BatchSeqHeader, strconv.FormatInt(int64(b.sequence), 10))
msg.Header.Set(BatchCommitHeader, "1")
resp, err := b.js.publishWithOptions(ctx, msg, nil)
resp, err := b.js.publishWithOptions(ctx, msg, pubOpts)
if err != nil {
return nil, err
}
@@ -174,7 +244,6 @@ func (b *batchPublisher) CommitMsg(ctx context.Context, msg *nats.Msg) (*BatchAc
return nil, ErrInvalidBatchAck
}
// Return BatchAck with server-provided values
return &BatchAck{
Stream: batchResp.PubAck.Stream,
Sequence: batchResp.PubAck.Sequence,
@@ -215,7 +284,6 @@ func (b *batchPublisher) IsClosed() bool {
// PublishMsgBatch publishes a batch of messages to a Stream and waits for an ack for the commit.
func (js *jetStream) PublishMsgBatch(ctx context.Context, messages []*nats.Msg) (*BatchAck, error) {
// Batch publish
var batchAck *BatchAck
var err error
msgs := len(messages)

View File

@@ -360,6 +360,9 @@ var (
// invalid.
ErrInvalidBatchAck JetStreamError = &jsError{message: "invalid jetstream batch publish response"}
// ErrBatchExpectLastSequenceNotFirst is returned when ExpectLastSequence options are used on non-first message in batch.
ErrBatchExpectLastSequenceNotFirst = &jsError{message: "ExpectLastSequence options can only be used on first message in batch"}
// ErrInvalidKey is returned when attempting to create a key with an invalid
// name.
ErrInvalidKey JetStreamError = &jsError{message: "invalid key"}

View File

@@ -667,3 +667,58 @@ func WithStallWait(ttl time.Duration) PublishOpt {
return nil
}
}
// WithBatchMsgTTL sets per msg TTL for batch messages.
// Requires [StreamConfig.AllowMsgTTL] to be enabled.
func WithBatchMsgTTL(dur time.Duration) BatchMsgOpt {
return func(opts *batchMsgOpts) error {
opts.ttl = dur
return nil
}
}
// WithBatchExpectStream sets the expected stream the message should be published to.
// If the message is published to a different stream server will reject the
// message and publish will fail.
func WithBatchExpectStream(stream string) BatchMsgOpt {
return func(opts *batchMsgOpts) error {
opts.stream = stream
return nil
}
}
// WithBatchExpectLastSequence sets the expected sequence number the last message
// on a stream should have. If the last message has a different sequence number
// server will reject the message and publish will fail.
func WithBatchExpectLastSequence(seq uint64) BatchMsgOpt {
return func(opts *batchMsgOpts) error {
opts.lastSeq = &seq
return nil
}
}
// WithBatchExpectLastSequencePerSubject sets the expected sequence number the last
// message on a subject the message is published to. If the last message on a
// subject has a different sequence number server will reject the message and
// publish will fail.
func WithBatchExpectLastSequencePerSubject(seq uint64) BatchMsgOpt {
return func(opts *batchMsgOpts) error {
opts.lastSubjectSeq = &seq
return nil
}
}
// WithBatchExpectLastSequenceForSubject sets the sequence and subject for which the
// last sequence number should be checked. If the last message on a subject
// has a different sequence number server will reject the message and publish
// will fail.
func WithBatchExpectLastSequenceForSubject(seq uint64, subject string) BatchMsgOpt {
return func(opts *batchMsgOpts) error {
if subject == "" {
return fmt.Errorf("%w: subject cannot be empty", ErrInvalidOption)
}
opts.lastSubjectSeq = &seq
opts.lastSubject = subject
return nil
}
}

View File

@@ -23,6 +23,72 @@ import (
"github.com/nats-io/nats.go/jetstream"
)
func TestBatchPublishLastSequence(t *testing.T) {
t.Skip("Skipping until expected last sequence is fixed in server")
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)
nc, js := jsClient(t, s)
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Create a stream with batch publishing enabled
cfg := jetstream.StreamConfig{
Name: "TEST",
Subjects: []string{"test.>"},
AllowAtomicPublish: true,
}
stream, err := js.CreateStream(ctx, cfg)
if err != nil {
t.Fatalf("Unexpected error creating stream: %v", err)
}
// publish a message to have a last sequence
_, err = js.Publish(ctx, "test.foo", []byte("hello"))
if err != nil {
t.Fatalf("Unexpected error publishing message: %v", err)
}
batch, err := js.BatchPublisher()
if err != nil {
t.Fatalf("Unexpected error creating batch publisher: %v", err)
}
// Add first message with ExpectLastSequence = 1
if err := batch.Add("test.1", []byte("message 1"), jetstream.WithBatchExpectLastSequence(1)); err != nil {
t.Fatalf("Unexpected error adding first message with ExpectLastSequence: %v", err)
}
// Add second message without ExpectLastSequence
if err := batch.Add("test.2", []byte("message 2")); err != nil {
t.Fatalf("Unexpected error adding second message: %v", err)
}
// Commit third message
ack, err := batch.Commit(ctx, "test.3", []byte("message 3"))
if err != nil {
t.Fatalf("Unexpected error committing batch: %v", err)
}
if ack == nil {
t.Fatal("Expected non-nil BatchAck")
}
// Verify ack contains expected stream
if ack.Stream != "TEST" {
t.Fatalf("Expected stream name to be TEST, got %s", ack.Stream)
}
info, err := stream.Info(ctx)
if err != nil {
t.Fatalf("Unexpected error getting stream info: %v", err)
}
if info.State.Msgs != 4 {
t.Fatalf("Expected 4 messages in the stream, got %d", info.State.Msgs)
}
}
func TestBatchPublisher(t *testing.T) {
t.Run("basic", func(t *testing.T) {
@@ -108,6 +174,146 @@ func TestBatchPublisher(t *testing.T) {
}
})
t.Run("with options", func(t *testing.T) {
t.Skip("Skipping until expected last sequence is fixed in server")
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)
nc, js := jsClient(t, s)
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Create a stream with batch publishing and TTL enabled
cfg := jetstream.StreamConfig{
Name: "TEST",
Subjects: []string{"test.>"},
AllowAtomicPublish: true,
AllowMsgTTL: true,
}
stream, err := js.CreateStream(ctx, cfg)
if err != nil {
t.Fatalf("Unexpected error creating stream: %v", err)
}
for range 5 {
if _, err := js.Publish(ctx, "test.foo", []byte("hello")); err != nil {
t.Fatalf("Unexpected error publishing message: %v", err)
}
}
info, err := stream.Info(ctx)
if err != nil {
t.Fatalf("Unexpected error getting stream info: %v", err)
}
if info.State.Msgs != 5 {
t.Fatalf("Expected 5 messages in the stream, got %d", info.State.Msgs)
}
time.Sleep(time.Second)
batch, err := js.BatchPublisher()
if err != nil {
t.Fatalf("Unexpected error creating batch publisher: %v", err)
}
// Add first message with TTL and ExpectLastSequence (allowed on first message)
if err := batch.Add("test.1", []byte("message 1"), jetstream.WithBatchMsgTTL(5*time.Second), jetstream.WithBatchExpectLastSequence(5)); err != nil {
t.Fatalf("Unexpected error adding first message with options: %v", err)
}
// Add second message with expected stream (no ExpectLastSequence)
if err := batch.AddMsg(&nats.Msg{
Subject: "test.2",
Data: []byte("message 2"),
}, jetstream.WithBatchExpectStream("TEST")); err != nil {
t.Fatalf("Unexpected error adding second message with expected stream: %v", err)
}
// Commit third message
ack, err := batch.Commit(ctx, "test.3", []byte("message 3"))
if err != nil {
t.Fatalf("Unexpected error committing batch with expected sequence: %v", err)
}
if ack == nil {
t.Fatal("Expected non-nil BatchAck")
}
// Verify ack contains expected stream
if ack.Stream != "TEST" {
t.Fatalf("Expected stream name to be TEST, got %s", ack.Stream)
}
info, err = stream.Info(ctx)
if err != nil {
t.Fatalf("Unexpected error getting stream info: %v", err)
}
if info.State.Msgs != 8 {
t.Fatalf("Expected 8 messages in the stream, got %d", info.State.Msgs)
}
})
t.Run("expect last sequence validation", func(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)
nc, js := jsClient(t, s)
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Create a stream with batch publishing enabled
cfg := jetstream.StreamConfig{
Name: "TEST",
Subjects: []string{"test.>"},
AllowAtomicPublish: true,
}
_, err := js.CreateStream(ctx, cfg)
if err != nil {
t.Fatalf("Unexpected error creating stream: %v", err)
}
batch, err := js.BatchPublisher()
if err != nil {
t.Fatalf("Unexpected error creating batch publisher: %v", err)
}
// First message with ExpectLastSequence should work
if err := batch.Add("test.1", []byte("message 1"), jetstream.WithBatchExpectLastSequence(0)); err != nil {
t.Fatalf("Unexpected error adding first message with ExpectLastSequence: %v", err)
}
// Second message with ExpectLastSequence should fail
if err := batch.Add("test.2", []byte("message 2"), jetstream.WithBatchExpectLastSequence(1)); err == nil {
t.Fatal("Expected error when using ExpectLastSequence on non-first message")
} else if !errors.Is(err, jetstream.ErrBatchExpectLastSequenceNotFirst) {
t.Fatalf("Expected ErrBatchExpectLastSequenceNotFirst, got %v", err)
}
// Second message without ExpectLastSequence should work
if err := batch.Add("test.2", []byte("message 2")); err != nil {
t.Fatalf("Unexpected error adding second message: %v", err)
}
// Commit with ExpectLastSequence should fail (not first message)
if _, err := batch.Commit(ctx, "test.3", []byte("message 3"), jetstream.WithBatchExpectLastSequence(2)); err == nil {
t.Fatal("Expected error when using ExpectLastSequence on commit (non-first message)")
} else if !errors.Is(err, jetstream.ErrBatchExpectLastSequenceNotFirst) {
t.Fatalf("Expected ErrBatchExpectLastSequenceNotFirst, got %v", err)
}
// Commit without ExpectLastSequence should work
ack, err := batch.Commit(ctx, "test.3", []byte("message 3"))
if err != nil {
t.Fatalf("Unexpected error committing batch: %v", err)
}
if ack == nil {
t.Fatal("Expected non-nil BatchAck")
}
})
t.Run("too many outstanding batches", func(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)
@@ -258,17 +464,36 @@ func TestBatchPublisher(t *testing.T) {
}
// Add messages until we exceed the max batch size (1000 messages)
for i := 0; i < 1000; i++ {
for i := 0; i < 999; i++ {
err = batch.Add("test.1", []byte("message 1"))
if err != nil {
t.Fatalf("Unexpected error adding message to batch: %v", err)
}
}
// commit is msg 1001
// commit is msg 1000 (within limit)
_, err = batch.Commit(ctx, "test.2", []byte("message 2"))
if !errors.Is(err, jetstream.ErrBatchPublishExceedsLimit) {
t.Fatalf("Expected ErrBatchPublishExceedsLimit, got %v", err)
if err != nil {
t.Fatalf("Unexpected error committing batch: %v", err)
}
// Try to create another batch and add 1001 messages
batch2, err := js.BatchPublisher()
if err != nil {
t.Fatalf("Unexpected error creating second batch publisher: %v", err)
}
for i := 0; i < 1000; i++ {
err = batch2.Add("test.1", []byte("message 1"))
if err != nil {
t.Fatalf("Unexpected error adding message to batch: %v", err)
}
}
// This should be message 1001 and should fail with incomplete error
_, err = batch2.Commit(ctx, "test.2", []byte("message 2"))
if !errors.Is(err, jetstream.ErrBatchPublishIncomplete) {
t.Fatalf("Expected ErrBatchPublishIncomplete, got %v", err)
}
})
@@ -428,8 +653,8 @@ func TestPublishMsgBatch(t *testing.T) {
}
_, err = js.PublishMsgBatch(ctx, messages)
if !errors.Is(err, jetstream.ErrBatchPublishExceedsLimit) {
t.Fatalf("Expected ErrBatchPublishExceedsLimit publishing too many messages, got %v", err)
if !errors.Is(err, jetstream.ErrBatchPublishIncomplete) {
t.Fatalf("Expected ErrBatchPublishIncomplete publishing too many messages, got %v", err)
}
})
}