mirror of
https://github.com/nats-io/nats.go.git
synced 2025-09-26 20:41:41 +08:00
[FIXED] Deadlock when accessing subscriptions map on consumer (#1671)
This fixes an issue where a deadlock could occur when calling `Stop()` or `Drain()` on `ConsumeContext` or `MessagesContext` and then calling `Consume` or `Messages` immediately. Switched to using a type-safe implementation of `sync.Map` for subscriptions map instead of locking the whole consumer state. Additionally, changed the type of atomic flags from `uint32` to `atomic.UInt32` to avoid accidental non-atomic reads/writes. Signed-off-by: Piotr Piotrowski <piotr@synadia.com> --------- Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
This commit is contained in:
14
go_test.mod
14
go_test.mod
@@ -1,23 +1,25 @@
|
||||
module github.com/nats-io/nats.go
|
||||
|
||||
go 1.19
|
||||
go 1.21
|
||||
|
||||
toolchain go1.22.5
|
||||
|
||||
require (
|
||||
github.com/golang/protobuf v1.4.2
|
||||
github.com/klauspost/compress v1.17.8
|
||||
github.com/klauspost/compress v1.17.9
|
||||
github.com/nats-io/jwt v1.2.2
|
||||
github.com/nats-io/nats-server/v2 v2.10.16
|
||||
github.com/nats-io/nats-server/v2 v2.10.17
|
||||
github.com/nats-io/nkeys v0.4.7
|
||||
github.com/nats-io/nuid v1.0.1
|
||||
go.uber.org/goleak v1.3.0
|
||||
golang.org/x/text v0.15.0
|
||||
golang.org/x/text v0.16.0
|
||||
google.golang.org/protobuf v1.23.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/minio/highwayhash v1.0.2 // indirect
|
||||
github.com/nats-io/jwt/v2 v2.5.7 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.24.0 // indirect
|
||||
golang.org/x/sys v0.21.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
)
|
||||
|
24
go_test.sum
24
go_test.sum
@@ -1,4 +1,5 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
|
||||
@@ -10,38 +11,40 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
|
||||
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=
|
||||
github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
|
||||
github.com/nats-io/jwt v1.2.2 h1:w3GMTO969dFg+UOKTmmyuu7IGdusK+7Ytlt//OYH/uU=
|
||||
github.com/nats-io/jwt v1.2.2/go.mod h1:/xX356yQA6LuXI9xWW7mZNpxgF2mBmGecH+Fj34sP5Q=
|
||||
github.com/nats-io/jwt/v2 v2.5.7 h1:j5lH1fUXCnJnY8SsQeB/a/z9Azgu2bYIDvtPVNdxe2c=
|
||||
github.com/nats-io/jwt/v2 v2.5.7/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A=
|
||||
github.com/nats-io/nats-server/v2 v2.10.16 h1:2jXaiydp5oB/nAx/Ytf9fdCi9QN6ItIc9eehX8kwVV0=
|
||||
github.com/nats-io/nats-server/v2 v2.10.16/go.mod h1:Pksi38H2+6xLe1vQx0/EA4bzetM0NqyIHcIbmgXSkIU=
|
||||
github.com/nats-io/nats-server/v2 v2.10.17 h1:PTVObNBD3TZSNUDgzFb1qQsQX4mOgFmOuG9vhT+KBUY=
|
||||
github.com/nats-io/nats-server/v2 v2.10.17/go.mod h1:5OUyc4zg42s/p2i92zbbqXvUNsbF0ivdTLKshVMn2YQ=
|
||||
github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s=
|
||||
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
|
||||
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
|
||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
@@ -54,3 +57,4 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
|
||||
google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
73
internal/syncx/map.go
Normal file
73
internal/syncx/map.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright 2024 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package syncx
|
||||
|
||||
import "sync"
|
||||
|
||||
// Map is a type-safe wrapper around sync.Map.
|
||||
// It is safe for concurrent use.
|
||||
// The zero value of Map is an empty map ready to use.
|
||||
type Map[K comparable, V any] struct {
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Load(key K) (V, bool) {
|
||||
v, ok := m.m.Load(key)
|
||||
if !ok {
|
||||
var empty V
|
||||
return empty, false
|
||||
}
|
||||
return v.(V), true
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Store(key K, value V) {
|
||||
m.m.Store(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Delete(key K) {
|
||||
m.m.Delete(key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
|
||||
m.m.Range(func(key, value any) bool {
|
||||
return f(key.(K), value.(V))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) LoadOrStore(key K, value V) (V, bool) {
|
||||
v, loaded := m.m.LoadOrStore(key, value)
|
||||
return v.(V), loaded
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) LoadAndDelete(key K) (V, bool) {
|
||||
v, ok := m.m.LoadAndDelete(key)
|
||||
if !ok {
|
||||
var empty V
|
||||
return empty, false
|
||||
}
|
||||
return v.(V), true
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) CompareAndSwap(key K, old, new V) bool {
|
||||
return m.m.CompareAndSwap(key, old, new)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) CompareAndDelete(key K, value V) bool {
|
||||
return m.m.CompareAndDelete(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Swap(key K, value V) (V, bool) {
|
||||
previous, loaded := m.m.Swap(key, value)
|
||||
return previous.(V), loaded
|
||||
}
|
152
internal/syncx/map_test.go
Normal file
152
internal/syncx/map_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright 2024 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package syncx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMapLoad(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
|
||||
v, ok := m.Load(1)
|
||||
if !ok || v != "one" {
|
||||
t.Errorf("Load(1) = %v, %v; want 'one', true", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Load(2)
|
||||
if ok || v != "" {
|
||||
t.Errorf("Load(2) = %v, %v; want '', false", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapStore(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
|
||||
v, ok := m.Load(1)
|
||||
if !ok || v != "one" {
|
||||
t.Errorf("Load(1) after Store(1, 'one') = %v, %v; want 'one', true", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapDelete(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
m.Delete(1)
|
||||
|
||||
v, ok := m.Load(1)
|
||||
if ok || v != "" {
|
||||
t.Errorf("Load(1) after Delete(1) = %v, %v; want '', false", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapRange(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
m.Store(2, "two")
|
||||
|
||||
var keys []int
|
||||
var values []string
|
||||
m.Range(func(key int, value string) bool {
|
||||
keys = append(keys, key)
|
||||
values = append(values, value)
|
||||
return true
|
||||
})
|
||||
|
||||
if len(keys) != 2 || len(values) != 2 {
|
||||
t.Errorf("Range() keys = %v, values = %v; want 2 keys and 2 values", keys, values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapLoadOrStore(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
|
||||
v, loaded := m.LoadOrStore(1, "one")
|
||||
if loaded || v != "one" {
|
||||
t.Errorf("LoadOrStore(1, 'one') = %v, %v; want 'one', false", v, loaded)
|
||||
}
|
||||
|
||||
v, loaded = m.LoadOrStore(1, "uno")
|
||||
if !loaded || v != "one" {
|
||||
t.Errorf("LoadOrStore(1, 'uno') = %v, %v; want 'one', true", v, loaded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapLoadAndDelete(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
|
||||
v, ok := m.LoadAndDelete(1)
|
||||
if !ok || v != "one" {
|
||||
t.Errorf("LoadAndDelete(1) = %v, %v; want 'one', true", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Load(1)
|
||||
if ok || v != "" {
|
||||
t.Errorf("Load(1) after LoadAndDelete(1) = %v, %v; want '', false", v, ok)
|
||||
}
|
||||
|
||||
// Test that LoadAndDelete on a missing key returns the zero value.
|
||||
v, ok = m.LoadAndDelete(2)
|
||||
if ok || v != "" {
|
||||
t.Errorf("LoadAndDelete(2) = %v, %v; want '', false", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapCompareAndSwap(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
|
||||
ok := m.CompareAndSwap(1, "one", "uno")
|
||||
if !ok {
|
||||
t.Errorf("CompareAndSwap(1, 'one', 'uno') = false; want true")
|
||||
}
|
||||
|
||||
v, _ := m.Load(1)
|
||||
if v != "uno" {
|
||||
t.Errorf("Load(1) after CompareAndSwap = %v; want 'uno'", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapCompareAndDelete(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
|
||||
ok := m.CompareAndDelete(1, "one")
|
||||
if !ok {
|
||||
t.Errorf("CompareAndDelete(1, 'one') = false; want true")
|
||||
}
|
||||
|
||||
v, _ := m.Load(1)
|
||||
if v != "" {
|
||||
t.Errorf("Load(1) after CompareAndDelete = %v; want ''", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapSwap(t *testing.T) {
|
||||
var m Map[int, string]
|
||||
m.Store(1, "one")
|
||||
|
||||
v, loaded := m.Swap(1, "uno")
|
||||
if !loaded || v != "one" {
|
||||
t.Errorf("Swap(1, 'uno') = %v, %v; want 'one', true", v, loaded)
|
||||
}
|
||||
|
||||
v, _ = m.Load(1)
|
||||
if v != "uno" {
|
||||
t.Errorf("Load(1) after Swap = %v; want 'uno'", v)
|
||||
}
|
||||
}
|
@@ -20,6 +20,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/nats-io/nats.go/internal/syncx"
|
||||
"github.com/nats-io/nuid"
|
||||
)
|
||||
|
||||
@@ -233,12 +234,12 @@ func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg Consu
|
||||
}
|
||||
|
||||
return &pullConsumer{
|
||||
jetStream: js,
|
||||
stream: stream,
|
||||
name: resp.Name,
|
||||
durable: cfg.Durable != "",
|
||||
info: resp.ConsumerInfo,
|
||||
subscriptions: make(map[string]*pullSubscription),
|
||||
jetStream: js,
|
||||
stream: stream,
|
||||
name: resp.Name,
|
||||
durable: cfg.Durable != "",
|
||||
info: resp.ConsumerInfo,
|
||||
subs: syncx.Map[string, *pullSubscription]{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -285,12 +286,12 @@ func getConsumer(ctx context.Context, js *jetStream, stream, name string) (Consu
|
||||
}
|
||||
|
||||
cons := &pullConsumer{
|
||||
jetStream: js,
|
||||
stream: stream,
|
||||
name: name,
|
||||
durable: resp.Config.Durable != "",
|
||||
info: resp.ConsumerInfo,
|
||||
subscriptions: make(map[string]*pullSubscription, 0),
|
||||
jetStream: js,
|
||||
stream: stream,
|
||||
name: name,
|
||||
durable: resp.Config.Durable != "",
|
||||
info: resp.ConsumerInfo,
|
||||
subs: syncx.Map[string, *pullSubscription]{},
|
||||
}
|
||||
|
||||
return cons, nil
|
||||
|
@@ -276,9 +276,11 @@ func TestRetryWithBackoff(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPullConsumer_checkPending(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
givenSub *pullSubscription
|
||||
fetchInProgress bool
|
||||
shouldSend bool
|
||||
expectedPullRequest *pullRequest
|
||||
}{
|
||||
@@ -292,7 +294,6 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
ThresholdMessages: 5,
|
||||
MaxMessages: 10,
|
||||
},
|
||||
fetchInProgress: 0,
|
||||
},
|
||||
shouldSend: false,
|
||||
},
|
||||
@@ -307,7 +308,6 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
ThresholdMessages: 5,
|
||||
MaxMessages: 10,
|
||||
},
|
||||
fetchInProgress: 0,
|
||||
},
|
||||
shouldSend: true,
|
||||
expectedPullRequest: &pullRequest{
|
||||
@@ -325,9 +325,9 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
ThresholdMessages: 5,
|
||||
MaxMessages: 10,
|
||||
},
|
||||
fetchInProgress: 1,
|
||||
},
|
||||
shouldSend: false,
|
||||
fetchInProgress: true,
|
||||
shouldSend: false,
|
||||
},
|
||||
{
|
||||
name: "pending bytes below threshold, send pull request",
|
||||
@@ -341,7 +341,6 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
ThresholdBytes: 500,
|
||||
MaxBytes: 1000,
|
||||
},
|
||||
fetchInProgress: 0,
|
||||
},
|
||||
shouldSend: true,
|
||||
expectedPullRequest: &pullRequest{
|
||||
@@ -359,7 +358,6 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
ThresholdBytes: 500,
|
||||
MaxBytes: 1000,
|
||||
},
|
||||
fetchInProgress: 0,
|
||||
},
|
||||
shouldSend: false,
|
||||
},
|
||||
@@ -373,9 +371,9 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
ThresholdBytes: 500,
|
||||
MaxBytes: 1000,
|
||||
},
|
||||
fetchInProgress: 1,
|
||||
},
|
||||
shouldSend: false,
|
||||
fetchInProgress: true,
|
||||
shouldSend: false,
|
||||
},
|
||||
{
|
||||
name: "StopAfter set, pending msgs below StopAfter, send pull request",
|
||||
@@ -388,8 +386,7 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
MaxMessages: 10,
|
||||
StopAfter: 8,
|
||||
},
|
||||
fetchInProgress: 0,
|
||||
delivered: 2,
|
||||
delivered: 2,
|
||||
},
|
||||
shouldSend: true,
|
||||
expectedPullRequest: &pullRequest{
|
||||
@@ -408,8 +405,7 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
MaxMessages: 10,
|
||||
StopAfter: 6,
|
||||
},
|
||||
fetchInProgress: 0,
|
||||
delivered: 0,
|
||||
delivered: 0,
|
||||
},
|
||||
shouldSend: false,
|
||||
},
|
||||
@@ -419,6 +415,9 @@ func TestPullConsumer_checkPending(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
prChan := make(chan *pullRequest, 1)
|
||||
test.givenSub.fetchNext = prChan
|
||||
if test.fetchInProgress {
|
||||
test.givenSub.fetchInProgress.Store(1)
|
||||
}
|
||||
errs := make(chan error, 1)
|
||||
ok := make(chan struct{}, 1)
|
||||
go func() {
|
||||
|
@@ -32,6 +32,7 @@ type (
|
||||
cfg *OrderedConsumerConfig
|
||||
stream string
|
||||
currentConsumer *pullConsumer
|
||||
currentSub ConsumeContext
|
||||
cursor cursor
|
||||
namePrefix string
|
||||
serial int
|
||||
@@ -116,19 +117,11 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
|
||||
}
|
||||
meta, err := msg.Metadata()
|
||||
if err != nil {
|
||||
sub, ok := c.currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.errHandler(serial)(sub, err)
|
||||
c.errHandler(serial)(c.currentSub, err)
|
||||
return
|
||||
}
|
||||
dseq := meta.Sequence.Consumer
|
||||
if dseq != c.cursor.deliverSeq+1 {
|
||||
sub, ok := c.currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.errHandler(serial)(sub, errOrderedSequenceMismatch)
|
||||
return
|
||||
}
|
||||
@@ -138,21 +131,18 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
|
||||
}
|
||||
}
|
||||
|
||||
_, err = c.currentConsumer.Consume(internalHandler(c.serial), opts...)
|
||||
cc, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.currentSub = cc
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.doReset:
|
||||
if err := c.reset(); err != nil {
|
||||
sub, ok := c.currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.errHandler(c.serial)(sub, err)
|
||||
c.errHandler(c.serial)(c.currentSub, err)
|
||||
}
|
||||
if c.withStopAfter {
|
||||
select {
|
||||
@@ -175,12 +165,12 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
|
||||
if c.withStopAfter {
|
||||
opts = append(opts, consumeStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft))
|
||||
}
|
||||
if _, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil {
|
||||
sub, ok := c.currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.errHandler(c.serial)(sub, err)
|
||||
if cc, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil {
|
||||
c.errHandler(c.serial)(cc, err)
|
||||
} else {
|
||||
c.Lock()
|
||||
c.currentSub = cc
|
||||
c.Unlock()
|
||||
}
|
||||
case <-sub.done:
|
||||
return
|
||||
@@ -250,10 +240,11 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er
|
||||
if c.stopAfter > 0 {
|
||||
opts = append(opts, messagesStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft))
|
||||
}
|
||||
_, err = c.currentConsumer.Messages(opts...)
|
||||
cc, err := c.currentConsumer.Messages(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.currentSub = cc
|
||||
|
||||
sub := &orderedSubscription{
|
||||
consumer: c,
|
||||
@@ -267,12 +258,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er
|
||||
|
||||
func (s *orderedSubscription) Next() (Msg, error) {
|
||||
for {
|
||||
currentConsumer := s.consumer.currentConsumer
|
||||
sub, ok := currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return nil, ErrMsgIteratorClosed
|
||||
}
|
||||
msg, err := sub.Next()
|
||||
msg, err := s.consumer.currentSub.(*pullSubscription).Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrMsgIteratorClosed) {
|
||||
s.Stop()
|
||||
@@ -292,10 +278,11 @@ func (s *orderedSubscription) Next() (Msg, error) {
|
||||
if err := s.consumer.reset(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err := s.consumer.currentConsumer.Messages(s.opts...)
|
||||
cc, err := s.consumer.currentConsumer.Messages(s.opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.consumer.currentSub = cc
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -312,10 +299,11 @@ func (s *orderedSubscription) Next() (Msg, error) {
|
||||
if err := s.consumer.reset(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err := s.consumer.currentConsumer.Messages(s.opts...)
|
||||
cc, err := s.consumer.currentConsumer.Messages(s.opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.consumer.currentSub = cc
|
||||
continue
|
||||
}
|
||||
s.consumer.cursor.deliverSeq = dseq
|
||||
@@ -328,13 +316,9 @@ func (s *orderedSubscription) Stop() {
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return
|
||||
}
|
||||
sub, ok := s.consumer.currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.consumer.currentConsumer.Lock()
|
||||
defer s.consumer.currentConsumer.Unlock()
|
||||
sub.Stop()
|
||||
s.consumer.Lock()
|
||||
defer s.consumer.Unlock()
|
||||
s.consumer.currentSub.Stop()
|
||||
close(s.done)
|
||||
}
|
||||
|
||||
@@ -342,13 +326,9 @@ func (s *orderedSubscription) Drain() {
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return
|
||||
}
|
||||
sub, ok := s.consumer.currentConsumer.getSubscription("")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.consumer.currentConsumer.Lock()
|
||||
defer s.consumer.currentConsumer.Unlock()
|
||||
sub.Drain()
|
||||
s.consumer.currentSub.Drain()
|
||||
close(s.done)
|
||||
}
|
||||
|
||||
@@ -495,10 +475,9 @@ func (c *orderedConsumer) reset() error {
|
||||
defer c.Unlock()
|
||||
defer atomic.StoreUint32(&c.resetInProgress, 0)
|
||||
if c.currentConsumer != nil {
|
||||
sub, ok := c.currentConsumer.getSubscription("")
|
||||
c.currentConsumer.Lock()
|
||||
if ok {
|
||||
sub.Stop()
|
||||
if c.currentSub != nil {
|
||||
c.currentSub.Stop()
|
||||
}
|
||||
consName := c.currentConsumer.CachedInfo().Name
|
||||
c.currentConsumer.Unlock()
|
||||
|
@@ -23,6 +23,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/internal/syncx"
|
||||
"github.com/nats-io/nuid"
|
||||
)
|
||||
|
||||
@@ -75,12 +76,12 @@ type (
|
||||
|
||||
pullConsumer struct {
|
||||
sync.Mutex
|
||||
jetStream *jetStream
|
||||
stream string
|
||||
durable bool
|
||||
name string
|
||||
info *ConsumerInfo
|
||||
subscriptions map[string]*pullSubscription
|
||||
jetStream *jetStream
|
||||
stream string
|
||||
durable bool
|
||||
name string
|
||||
info *ConsumerInfo
|
||||
subs syncx.Map[string, *pullSubscription]
|
||||
}
|
||||
|
||||
pullRequest struct {
|
||||
@@ -116,9 +117,9 @@ type (
|
||||
errs chan error
|
||||
pending pendingMsgs
|
||||
hbMonitor *hbMonitor
|
||||
fetchInProgress uint32
|
||||
closed uint32
|
||||
draining uint32
|
||||
fetchInProgress atomic.Uint32
|
||||
closed atomic.Uint32
|
||||
draining atomic.Uint32
|
||||
done chan struct{}
|
||||
connStatusChanged chan nats.Status
|
||||
fetchNext chan *pullRequest
|
||||
@@ -181,12 +182,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
|
||||
|
||||
subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name))
|
||||
|
||||
// for single consume, use empty string as id
|
||||
// this is useful for ordered consumer, where only a single subscription is valid
|
||||
var consumeID string
|
||||
if len(p.subscriptions) > 0 {
|
||||
consumeID = nuid.Next()
|
||||
}
|
||||
consumeID := nuid.Next()
|
||||
sub := &pullSubscription{
|
||||
id: consumeID,
|
||||
consumer: p,
|
||||
@@ -199,7 +195,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
|
||||
|
||||
sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat)
|
||||
|
||||
p.subscriptions[sub.id] = sub
|
||||
p.subs.Store(sub.id, sub)
|
||||
p.Unlock()
|
||||
|
||||
internalHandler := func(msg *nats.Msg) {
|
||||
@@ -232,7 +228,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
|
||||
sub.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if atomic.LoadUint32(&sub.closed) == 1 {
|
||||
if sub.closed.Load() == 1 {
|
||||
return
|
||||
}
|
||||
if sub.consumeOpts.ErrHandler != nil {
|
||||
@@ -259,10 +255,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
|
||||
}
|
||||
sub.subscription.SetClosedHandler(func(sid string) func(string) {
|
||||
return func(subject string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
delete(p.subscriptions, sid)
|
||||
atomic.CompareAndSwapUint32(&sub.draining, 1, 0)
|
||||
p.subs.Delete(sid)
|
||||
sub.draining.CompareAndSwap(1, 0)
|
||||
}
|
||||
}(sub.id))
|
||||
|
||||
@@ -286,7 +280,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
|
||||
go func() {
|
||||
isConnected := true
|
||||
for {
|
||||
if atomic.LoadUint32(&sub.closed) == 1 {
|
||||
if sub.closed.Load() == 1 {
|
||||
return
|
||||
}
|
||||
select {
|
||||
@@ -383,7 +377,7 @@ func (s *pullSubscription) incrementDeliveredMsgs() {
|
||||
func (s *pullSubscription) checkPending() {
|
||||
if (s.pending.msgCount < s.consumeOpts.ThresholdMessages ||
|
||||
(s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0)) &&
|
||||
atomic.LoadUint32(&s.fetchInProgress) == 0 {
|
||||
s.fetchInProgress.Load() == 0 {
|
||||
|
||||
var batchSize, maxBytes int
|
||||
if s.consumeOpts.MaxBytes == 0 {
|
||||
@@ -427,12 +421,7 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
|
||||
|
||||
msgs := make(chan *nats.Msg, consumeOpts.MaxMessages)
|
||||
|
||||
// for single consume, use empty string as id
|
||||
// this is useful for ordered consumer, where only a single subscription is valid
|
||||
var consumeID string
|
||||
if len(p.subscriptions) > 0 {
|
||||
consumeID = nuid.Next()
|
||||
}
|
||||
consumeID := nuid.Next()
|
||||
sub := &pullSubscription{
|
||||
id: consumeID,
|
||||
consumer: p,
|
||||
@@ -451,20 +440,18 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
|
||||
}
|
||||
sub.subscription.SetClosedHandler(func(sid string) func(string) {
|
||||
return func(subject string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
if atomic.LoadUint32(&sub.draining) != 1 {
|
||||
if sub.draining.Load() != 1 {
|
||||
// if we're not draining, subscription can be closed as soon
|
||||
// as closed handler is called
|
||||
// otherwise, we need to wait until all messages are drained
|
||||
// in Next
|
||||
delete(p.subscriptions, sid)
|
||||
p.subs.Delete(sid)
|
||||
}
|
||||
close(msgs)
|
||||
}
|
||||
}(sub.id))
|
||||
|
||||
p.subscriptions[sub.id] = sub
|
||||
p.subs.Store(sub.id, sub)
|
||||
p.Unlock()
|
||||
|
||||
go sub.pullMessages(subject)
|
||||
@@ -502,8 +489,8 @@ var (
|
||||
func (s *pullSubscription) Next() (Msg, error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
drainMode := atomic.LoadUint32(&s.draining) == 1
|
||||
closed := atomic.LoadUint32(&s.closed) == 1
|
||||
drainMode := s.draining.Load() == 1
|
||||
closed := s.closed.Load() == 1
|
||||
if closed && !drainMode {
|
||||
return nil, ErrMsgIteratorClosed
|
||||
}
|
||||
@@ -526,8 +513,8 @@ func (s *pullSubscription) Next() (Msg, error) {
|
||||
case msg, ok := <-s.msgs:
|
||||
if !ok {
|
||||
// if msgs channel is closed, it means that subscription was either drained or stopped
|
||||
delete(s.consumer.subscriptions, s.id)
|
||||
atomic.CompareAndSwapUint32(&s.draining, 1, 0)
|
||||
s.consumer.subs.Delete(s.id)
|
||||
s.draining.CompareAndSwap(1, 0)
|
||||
return nil, ErrMsgIteratorClosed
|
||||
}
|
||||
if hbMonitor != nil {
|
||||
@@ -630,7 +617,7 @@ func (hb *hbMonitor) Reset(dur time.Duration) {
|
||||
// Next after calling Stop will return ErrMsgIteratorClosed error.
|
||||
// All messages that are already in the buffer are discarded.
|
||||
func (s *pullSubscription) Stop() {
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
if !s.closed.CompareAndSwap(0, 1) {
|
||||
return
|
||||
}
|
||||
close(s.done)
|
||||
@@ -648,10 +635,10 @@ func (s *pullSubscription) Stop() {
|
||||
// subsequent calls to Next. After the buffer is drained, Next will
|
||||
// return ErrMsgIteratorClosed error.
|
||||
func (s *pullSubscription) Drain() {
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
if !s.closed.CompareAndSwap(0, 1) {
|
||||
return
|
||||
}
|
||||
atomic.StoreUint32(&s.draining, 1)
|
||||
s.draining.Store(1)
|
||||
close(s.done)
|
||||
if s.consumeOpts.stopAfterMsgsLeft != nil {
|
||||
if s.delivered >= s.consumeOpts.StopAfter {
|
||||
@@ -840,7 +827,7 @@ func (s *pullSubscription) pullMessages(subject string) {
|
||||
for {
|
||||
select {
|
||||
case req := <-s.fetchNext:
|
||||
atomic.StoreUint32(&s.fetchInProgress, 1)
|
||||
s.fetchInProgress.Store(1)
|
||||
|
||||
if err := s.pull(req, subject); err != nil {
|
||||
if errors.Is(err, ErrMsgIteratorClosed) {
|
||||
@@ -849,7 +836,7 @@ func (s *pullSubscription) pullMessages(subject string) {
|
||||
}
|
||||
s.errs <- err
|
||||
}
|
||||
atomic.StoreUint32(&s.fetchInProgress, 0)
|
||||
s.fetchInProgress.Store(0)
|
||||
case <-s.done:
|
||||
s.cleanup()
|
||||
return
|
||||
@@ -880,13 +867,13 @@ func (s *pullSubscription) cleanup() {
|
||||
if s.hbMonitor != nil {
|
||||
s.hbMonitor.Stop()
|
||||
}
|
||||
drainMode := atomic.LoadUint32(&s.draining) == 1
|
||||
drainMode := s.draining.Load() == 1
|
||||
if drainMode {
|
||||
s.subscription.Drain()
|
||||
} else {
|
||||
s.subscription.Unsubscribe()
|
||||
}
|
||||
atomic.StoreUint32(&s.closed, 1)
|
||||
s.closed.Store(1)
|
||||
}
|
||||
|
||||
// pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription].
|
||||
@@ -894,7 +881,7 @@ func (s *pullSubscription) cleanup() {
|
||||
func (s *pullSubscription) pull(req *pullRequest, subject string) error {
|
||||
s.consumer.Lock()
|
||||
defer s.consumer.Unlock()
|
||||
if atomic.LoadUint32(&s.closed) == 1 {
|
||||
if s.closed.Load() == 1 {
|
||||
return ErrMsgIteratorClosed
|
||||
}
|
||||
if req.Batch < 1 {
|
||||
@@ -994,10 +981,3 @@ func (consumeOpts *consumeOpts) setDefaults(ordered bool) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *pullConsumer) getSubscription(id string) (*pullSubscription, bool) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
sub, ok := c.subscriptions[id]
|
||||
return sub, ok
|
||||
}
|
||||
|
Reference in New Issue
Block a user