[FIXED] Consumer.Next() hangs after connection is closed (#1883)

* [FIXED] JetStream consumer hanging after connection closure

Fixes #1875

JetStream consumer operations would hang indefinitely after server
shutdown or connection closure.

Changes:
- Fixed core NATS bug where the closed handler was not called for ChanSubscription and SyncSubscription when the connection was closed
- Added ErrConnectionClosed error for better error context
- Enhanced Consume() to report connection closure via ConsumeErrHandler
- Improved Next() to return an error wrapping both ErrMsgIteratorClosed and ErrConnectionClosed when the iterator is closed because the connection was closed

Both Messages().Next() and consumer.Consume() now properly handle connection
closure and provide appropriate error reporting to applications.

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>

* Fix consume not sending conn closed error, improve tests

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>

---------

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
Piotr Piotrowski
2025-06-20 15:44:55 +02:00
committed by GitHub
parent 8a48023f77
commit 0879849951
7 changed files with 381 additions and 13 deletions

2
.gitignore vendored
View File

@@ -42,4 +42,4 @@ _testmain.go
.idea
# VS Code
.vscode
.vscode

View File

@@ -259,6 +259,14 @@ var (
// closed iterator.
ErrMsgIteratorClosed JetStreamError = &jsError{message: "messages iterator closed"}
// ErrConnectionClosed is returned when JetStream operations fail due to
// underlying connection being closed.
ErrConnectionClosed JetStreamError = &jsError{message: "connection closed"}
// ErrServerShutdown is returned when pull request fails due to server
// shutdown.
ErrServerShutdown JetStreamError = &jsError{message: "server shutdown"}
// ErrOrderedConsumerReset is returned when resetting ordered consumer fails
// due to too many attempts.
ErrOrderedConsumerReset JetStreamError = &jsError{message: "recreating ordered consumer"}

View File

@@ -141,13 +141,13 @@ type (
)
const (
controlMsg = "100"
badRequest = "400"
noMessages = "404"
reqTimeout = "408"
maxBytesExceeded = "409"
noResponders = "503"
pinIdMismatch = "423"
controlMsg = "100"
badRequest = "400"
noMessages = "404"
reqTimeout = "408"
conflict = "409"
noResponders = "503"
pinIdMismatch = "423"
)
// Headers used when publishing messages.
@@ -424,7 +424,7 @@ func checkMsg(msg *nats.Msg) (bool, error) {
return false, nil
case pinIdMismatch:
return false, ErrPinIDMismatch
case maxBytesExceeded:
case conflict:
if strings.Contains(strings.ToLower(descr), "message size exceeds maxbytes") {
return false, ErrMaxBytesExceeded
}
@@ -437,6 +437,9 @@ func checkMsg(msg *nats.Msg) (bool, error) {
if strings.Contains(strings.ToLower(descr), "leadership change") {
return false, ErrConsumerLeadershipChanged
}
if strings.Contains(strings.ToLower(descr), "server shutdown") {
return false, ErrServerShutdown
}
}
return false, fmt.Errorf("nats: %s", msg.Header.Get("Description"))
}

View File

@@ -222,7 +222,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
fetchNext: make(chan *pullRequest, 1),
consumeOpts: consumeOpts,
}
sub.connStatusChanged = p.js.conn.StatusChanged(nats.CONNECTED, nats.RECONNECTING)
sub.connStatusChanged = p.js.conn.StatusChanged(nats.CONNECTED, nats.RECONNECTING, nats.CLOSED)
sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat)
@@ -332,13 +332,13 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
if !ok {
continue
}
if status == nats.RECONNECTING {
switch status {
case nats.RECONNECTING:
if sub.hbMonitor != nil {
sub.hbMonitor.Stop()
}
isConnected = false
}
if status == nats.CONNECTED {
case nats.CONNECTED:
sub.Lock()
if !isConnected {
isConnected = true
@@ -362,6 +362,9 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
sub.resetPendingMsgs()
}
sub.Unlock()
case nats.CLOSED:
sub.errs <- ErrConnectionClosed
}
case err := <-sub.errs:
sub.Lock()
@@ -389,6 +392,9 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
sub.resetPendingMsgs()
}
sub.Unlock()
if errors.Is(err, ErrConnectionClosed) {
sub.Stop()
}
case <-sub.done:
return
}
@@ -569,6 +575,10 @@ func (s *pullSubscription) Next() (Msg, error) {
drainMode := s.draining.Load() == 1
closed := s.closed.Load() == 1
if closed && !drainMode {
// Check if iterator was closed due to connection closure
if s.consumer.js.conn.IsClosed() {
return nil, fmt.Errorf("%w: %w", ErrMsgIteratorClosed, ErrConnectionClosed)
}
return nil, ErrMsgIteratorClosed
}
hbMonitor := s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat)
@@ -592,6 +602,10 @@ func (s *pullSubscription) Next() (Msg, error) {
// if msgs channel is closed, it means that subscription was either drained or stopped
s.consumer.subs.Delete(s.id)
s.draining.CompareAndSwap(1, 0)
// Check if iterator was closed due to connection closure
if s.consumer.js.conn.IsClosed() {
return nil, fmt.Errorf("%w: %w", ErrMsgIteratorClosed, ErrConnectionClosed)
}
return nil, ErrMsgIteratorClosed
}
if hbMonitor != nil {

View File

@@ -3483,3 +3483,268 @@ func TestPullConsumerNext(t *testing.T) {
}
})
}
// TestPullConsumerConnectionClosed checks how pull consumer operations behave
// when the client explicitly closes its connection with nc.Close().
//
// Subtests:
//   - "messages": a Messages().Next() call started before the close must
//     return an error matching both ErrMsgIteratorClosed and
//     ErrConnectionClosed instead of hanging.
//   - "consume": the ConsumeErrHandler must receive ErrConnectionClosed, and
//     the channel returned by consumeCtx.Closed() must become readable.
func TestPullConsumerConnectionClosed(t *testing.T) {
t.Run("messages", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
Name: "test-stream",
Subjects: []string{"test.>"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
Name: "test-consumer",
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := consumer.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Run Next() in its own goroutine so the test can bound how long it waits
// for the call to return. errC is buffered so the goroutine never blocks on
// the send, even if the test has already timed out.
errC := make(chan error, 1)
go func() {
_, err := msgs.Next()
errC <- err
}()
// Brief pause, presumably so Next() is already waiting when the connection
// is closed.
time.Sleep(100 * time.Millisecond)
nc.Close()
select {
case err := <-errC:
// The returned error must match both sentinels via errors.Is.
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Expected error to contain ErrMsgIteratorClosed, got: %v", err)
}
if !errors.Is(err, jetstream.ErrConnectionClosed) {
t.Fatalf("Expected error to contain ErrConnectionClosed, got: %v", err)
}
case <-time.After(10 * time.Second):
// Next() did not return within 10 seconds of the close.
t.Fatal("Next() hung indefinitely after connection closed")
}
})
t.Run("consume", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
Name: "test-stream",
Subjects: []string{"test.>"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
Name: "test-consumer",
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// The message handler is intentionally empty; only errors delivered to the
// ConsumeErrHandler are of interest and they are forwarded to errC.
errC := make(chan error, 1)
consumeCtx, err := consumer.Consume(func(msg jetstream.Msg) {
}, jetstream.ConsumeErrHandler(func(cc jetstream.ConsumeContext, err error) {
errC <- err
}))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer consumeCtx.Stop()
// Brief pause before closing, presumably to let Consume() start up.
time.Sleep(100 * time.Millisecond)
nc.Close()
select {
case err := <-errC:
if !errors.Is(err, jetstream.ErrConnectionClosed) {
t.Fatalf("Expected ErrConnectionClosed, got: %v", err)
}
// After the error is reported, the consume context itself must close.
select {
case <-consumeCtx.Closed():
case <-time.After(3 * time.Second):
t.Fatal("Received error but Consume context was not closed")
}
case <-time.After(3 * time.Second):
t.Fatal("Consume did not return error after connection closed")
}
})
}
// TestPullConsumerMaxReconnectsExceeded checks how pull consumer operations
// behave when the server is shut down and the client, configured with
// MaxReconnects(3) and ReconnectWait(100ms), is left unable to reconnect.
//
// Subtests:
//   - "messages": a Messages().Next() call started before the shutdown must
//     return an error matching both ErrMsgIteratorClosed and
//     ErrConnectionClosed.
//   - "consume": the ConsumeErrHandler must first receive ErrServerShutdown
//     while the consume context stays open, then ErrConnectionClosed, after
//     which the consume context must close.
func TestPullConsumerMaxReconnectsExceeded(t *testing.T) {
t.Run("messages", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL(),
nats.MaxReconnects(3),
nats.ReconnectWait(100*time.Millisecond),
)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
Name: "test-stream",
Subjects: []string{"test.>"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
Name: "test-consumer",
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs, err := consumer.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Run Next() in its own goroutine so the test can bound how long it waits
// for the call to return; errC is buffered so the send never blocks.
errC := make(chan error, 1)
go func() {
_, err := msgs.Next()
errC <- err
}()
time.Sleep(100 * time.Millisecond)
// Shut the server down while Next() is outstanding.
// NOTE(review): the deferred shutdownJSServerAndRemoveStorage above runs
// again at subtest exit; presumably a second call is harmless — confirm.
shutdownJSServerAndRemoveStorage(t, srv)
select {
case err := <-errC:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Expected error to contain ErrMsgIteratorClosed, got: %v", err)
}
if !errors.Is(err, jetstream.ErrConnectionClosed) {
t.Fatalf("Expected error to contain ErrConnectionClosed, got: %v", err)
}
case <-time.After(15 * time.Second):
t.Fatal("Next() hung after reconnection attempts exhausted")
}
})
t.Run("consume", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL(),
nats.MaxReconnects(3),
nats.ReconnectWait(100*time.Millisecond),
)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
stream, err := js.CreateStream(ctx, jetstream.StreamConfig{
Name: "test-stream",
Subjects: []string{"test.>"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
consumer, err := stream.CreateConsumer(ctx, jetstream.ConsumerConfig{
Name: "test-consumer",
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// The message handler is intentionally empty; errors delivered to the
// ConsumeErrHandler are forwarded to errC and read one at a time below.
errC := make(chan error, 1)
consumeCtx, err := consumer.Consume(func(msg jetstream.Msg) {
}, jetstream.ConsumeErrHandler(func(cc jetstream.ConsumeContext, err error) {
errC <- err
}))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer consumeCtx.Stop()
time.Sleep(100 * time.Millisecond)
// NOTE(review): the deferred shutdownJSServerAndRemoveStorage above runs
// again at subtest exit; presumably a second call is harmless — confirm.
shutdownJSServerAndRemoveStorage(t, srv)
// first, we should receive Server Shutdown error from server
select {
case err := <-errC:
if !errors.Is(err, jetstream.ErrServerShutdown) {
t.Fatalf("Expected error to contain ErrServerShutdown, got: %v", err)
}
// consume context should not be closed yet because client tries to reconnect
select {
case <-consumeCtx.Closed():
t.Fatalf("Consume context should not be closed after server shutdown error")
case <-time.After(100 * time.Millisecond):
}
case <-time.After(3 * time.Second):
t.Fatal("Consume did not return error after server shutdown")
}
// now we should receive connection closed error after all reconnection attempts exhausted
// and consume context should be closed
select {
case err := <-errC:
if !errors.Is(err, jetstream.ErrConnectionClosed) {
t.Fatalf("Expected ErrConnectionClosed, got: %v", err)
}
select {
case <-consumeCtx.Closed():
case <-time.After(3 * time.Second):
t.Fatal("Received error but Consume context was not closed")
}
case <-time.After(3 * time.Second):
t.Fatal("Consume did not return error after connection closed")
}
})
}

13
nats.go
View File

@@ -5509,6 +5509,14 @@ func (nc *Conn) close(status Status, doCBs bool, err error) {
close(s.mch)
}
s.mch = nil
// Call closed handler for non-AsyncSubscription types (AsyncSubscription handlers
// are called by waitForMsgs when it exits)
var done func(string)
if s.typ != AsyncSubscription && s.pDone != nil {
done = s.pDone
}
// Mark as invalid, for signaling to waitForMsgs
s.closed = true
// Mark connection closed in subscription
@@ -5519,6 +5527,11 @@ func (nc *Conn) close(status Status, doCBs bool, err error) {
}
s.mu.Unlock()
// Call the closed handler outside the lock to avoid potential deadlocks
if done != nil {
done(s.Subject)
}
}
nc.subs = nil
nc.subsMu.Unlock()

View File

@@ -1771,6 +1771,71 @@ func TestMaxSubscriptionsExceeded(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}
// TestClosedHandlerOnConnectionClose verifies that a handler registered with
// Subscription.SetClosedHandler is invoked when the connection is closed with
// nc.Close(). It covers subscriptions created with Subscribe, ChanSubscribe
// and SubscribeSync, each in its own subtest against one shared server.
func TestClosedHandlerOnConnectionClose(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()
// Shared by all subtests: each closed handler sends one value here. The
// buffer of 1 lets a handler send without waiting for the receiver.
closedHandlerCalled := make(chan struct{}, 1)
t.Run("subscribe", func(t *testing.T) {
nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Callback-based subscription with a no-op message handler.
sub, err := nc.Subscribe("test.subject", func(m *nats.Msg) {})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
sub.SetClosedHandler(func(subject string) {
closedHandlerCalled <- struct{}{}
})
nc.Close()
// WaitOnChannel is defined elsewhere; presumably it fails the test if the
// expected value is not received in time — confirm against the helper.
WaitOnChannel(t, closedHandlerCalled, struct{}{})
})
t.Run("chan subscribe", func(t *testing.T) {
nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Channel-based subscription delivering into msgCh.
msgCh := make(chan *nats.Msg, 64)
sub, err := nc.ChanSubscribe("test.subject", msgCh)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
sub.SetClosedHandler(func(subject string) {
closedHandlerCalled <- struct{}{}
})
nc.Close()
WaitOnChannel(t, closedHandlerCalled, struct{}{})
})
t.Run("sync subscribe", func(t *testing.T) {
nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Synchronous subscription; no messages are read in this test.
sub, err := nc.SubscribeSync("test.subject")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
sub.SetClosedHandler(func(subject string) {
closedHandlerCalled <- struct{}{}
})
nc.Close()
WaitOnChannel(t, closedHandlerCalled, struct{}{})
})
}
func TestSubscribeSyncPermissionError(t *testing.T) {
conf := createConfFile(t, []byte(`
listen: 127.0.0.1:-1