diff --git a/.golangci.yaml b/.golangci.yaml index 4273d48..e014913 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -5,4 +5,6 @@ issues: - linters: - errcheck text: "Unsubscribe" - path: jsv2/jetstream/consumer.go \ No newline at end of file + - linters: + - errcheck + text: "msg.Ack" \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index a670b71..a5f06f4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ before_script: find . -type f -name "*.go" | xargs misspell -error -locale US; GOFLAGS="-mod=mod -modfile=go_test.mod" staticcheck ./...; fi -- golangci-lint run ./jsv2/... +- golangci-lint run ./jetstream/... script: - go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast -vet=off - if [[ "$TRAVIS_GO_VERSION" =~ 1.20 ]]; then ./scripts/cov.sh TRAVIS; else go test -modfile=go_test.mod -race -v -p=1 ./... --failfast -vet=off; fi diff --git a/examples/jsv2/js-consume/main.go b/examples/jsv2/js-consume/main.go new file mode 100644 index 0000000..80c648e --- /dev/null +++ b/examples/jsv2/js-consume/main.go @@ -0,0 +1,88 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + defer cancel() + + nc, err := nats.Connect("nats://127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + + js, err := jetstream.New(nc) + if err != nil { + log.Fatal(err) + } + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) + if err != nil { + log.Fatal(err) + } + + cons, err := s.AddConsumer(ctx, jetstream.ConsumerConfig{ + Durable: "TestConsumerConsume", + AckPolicy: jetstream.AckExplicitPolicy, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + cc, err := cons.Consume(func(msg jetstream.Msg) { + fmt.Println(string(msg.Data())) + msg.Ack() + }, jetstream.ConsumeErrHandler(func(consumeCtx jetstream.ConsumeContext, err error) { + fmt.Println(err) + })) + if err != nil { + log.Fatal(err) + } + defer cc.Stop() + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + <-sig + +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } +} diff --git a/examples/jsv2/js-fetch/main.go b/examples/jsv2/js-fetch/main.go new file mode 100644 index 0000000..7177821 --- /dev/null +++ b/examples/jsv2/js-fetch/main.go @@ -0,0 +1,83 @@ +// Copyright 2020-2022 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + defer cancel() + + nc, err := nats.Connect("nats://127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + + js, err := jetstream.New(nc) + if err != nil { + log.Fatal(err) + } + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) + if err != nil { + log.Fatal(err) + } + + cons, err := s.AddConsumer(ctx, jetstream.ConsumerConfig{ + Durable: "TestConsumerListener", + AckPolicy: jetstream.AckExplicitPolicy, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + for { + msgs, err := cons.Fetch(100, jetstream.FetchMaxWait(1*time.Second)) + if err != nil { + fmt.Println(err) + } + for msg := range msgs.Messages() { + fmt.Println(string(msg.Data())) + msg.Ack() + } + if msgs.Error() != nil { + fmt.Println("Error fetching messages: ", err) + } + } +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } +} diff --git a/examples/jsv2/js-listener/main.go b/examples/jsv2/js-messages/main.go similarity index 53% rename from examples/jsv2/js-listener/main.go rename to examples/jsv2/js-messages/main.go index 7657122..4d9fad6 100644 --- a/examples/jsv2/js-listener/main.go +++ b/examples/jsv2/js-messages/main.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -20,11 +20,11 @@ import ( "time" "github.com/nats-io/nats.go" - "github.com/nats-io/nats.go/jsv2/jetstream" + "github.com/nats-io/nats.go/jetstream" ) func main() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) defer cancel() nc, err := nats.Connect("127.0.0.1:4222") @@ -36,26 +36,47 @@ func main() { if err != nil { log.Fatal(err) } - s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "TEST_STREAM", Subjects: []string{"FOO.*"}}) - if err != nil { - log.Fatal(err) - } - - cons, err := s.CreateConsumer(ctx, jetstream.ConsumerConfig{Durable: "TestConsumerListener", AckPolicy: jetstream.AckExplicitPolicy}) - if err != nil { - log.Fatal(err) - } - - l, err := cons.Listener(func(msg jetstream.Msg, err error) { - if err != nil { - log.Fatal(err) - } - fmt.Println(string(msg.Data())) - msg.Ack() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, }) if err != nil { log.Fatal(err) } - defer l.Stop() + cons, err := s.AddConsumer(ctx, jetstream.ConsumerConfig{ + Durable: "TestConsumerMessages", + AckPolicy: jetstream.AckExplicitPolicy, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + it, err := cons.Messages(jetstream.PullMaxMessages(1)) + if err != nil { + log.Fatal(err) + } + for { + msg, err := it.Next() + if err != nil { + fmt.Println("next err: ", err) + } + fmt.Println(string(msg.Data())) + msg.Ack() + } +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } } diff --git a/examples/jsv2/js-next/main.go b/examples/jsv2/js-next/main.go index 33b3431..649d8ea 100644 --- a/examples/jsv2/js-next/main.go +++ b/examples/jsv2/js-next/main.go @@ -20,46 +20,60 @@ import ( "time" "github.com/nats-io/nats.go" - "github.com/nats-io/nats.go/jsv2/jetstream" + "github.com/nats-io/nats.go/jetstream" ) func main() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) defer cancel() - nc, err := nats.Connect("127.0.0.1:4222") + nc, err := nats.Connect("nats://127.0.0.1:4222") if err != nil { log.Fatal(err) } - defer nc.Flush() js, err := jetstream.New(nc) if err != nil { log.Fatal(err) } - - s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "TEST_STREAM", Subjects: []string{"FOO.*"}}) + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) if err != nil { log.Fatal(err) } - cons, err := s.CreateConsumer(ctx, jetstream.ConsumerConfig{Durable: "TestConsumerReader", AckPolicy: jetstream.AckExplicitPolicy}) + cons, err := s.AddConsumer(ctx, jetstream.ConsumerConfig{ + Durable: "TestConsumerListener", + AckPolicy: jetstream.AckExplicitPolicy, + }) if err != nil { log.Fatal(err) } + go endlessPublish(ctx, nc, js) - reader, err := cons.Reader() - if err != nil { - log.Fatal(err) - } - for i := 0; i < 10; i++ { - msg, err := reader.Next() + for { + msg, err := cons.Next() if err != nil { - log.Fatal(err) + fmt.Println(err) + continue } fmt.Println(string(msg.Data())) msg.Ack() } - - reader.Stop() +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } } diff --git a/examples/jsv2/js-ordered-consume/main.go b/examples/jsv2/js-ordered-consume/main.go new file mode 100644 index 0000000..e44e08d --- /dev/null +++ b/examples/jsv2/js-ordered-consume/main.go @@ -0,0 +1,83 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + defer cancel() + + nc, err := nats.Connect("nats://127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + + js, err := jetstream.New(nc) + if err != nil { + log.Fatal(err) + } + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) + if err != nil { + log.Fatal(err) + } + + cons, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{ + MaxResetAttempts: 5, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + _, err = cons.Consume(func(msg jetstream.Msg) { + fmt.Println(string(msg.Data())) + msg.Ack() + }) + if err != nil { + log.Fatal(err) + } + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + <-sig + +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } +} diff --git a/examples/jsv2/js-ordered-fetch/main.go b/examples/jsv2/js-ordered-fetch/main.go new file mode 100644 index 0000000..6b62df8 --- /dev/null +++ b/examples/jsv2/js-ordered-fetch/main.go @@ -0,0 +1,82 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + defer cancel() + + nc, err := nats.Connect("nats://127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + + js, err := jetstream.New(nc) + if err != nil { + log.Fatal(err) + } + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) + if err != nil { + log.Fatal(err) + } + + cons, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{ + MaxResetAttempts: 5, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + for { + msgs, err := cons.Fetch(100) + if err != nil { + fmt.Println(err) + } + for msg := range msgs.Messages() { + fmt.Println(string(msg.Data())) + msg.Ack() + } + if msgs.Error() != nil { + fmt.Println("Error fetching messages: ", err) + } + } +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } +} diff --git a/examples/jsv2/js-ordered-messages/main.go b/examples/jsv2/js-ordered-messages/main.go new file mode 100644 index 0000000..72a7807 --- /dev/null +++ b/examples/jsv2/js-ordered-messages/main.go @@ -0,0 +1,82 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + defer cancel() + + nc, err := nats.Connect("nats://127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + + js, err := jetstream.New(nc) + if err != nil { + log.Fatal(err) + } + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) + if err != nil { + log.Fatal(err) + } + + cons, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{ + MaxResetAttempts: 5, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + it, err := cons.Messages() + if err != nil { + log.Fatal(err) + } + defer it.Stop() + for { + msg, err := it.Next() + if err != nil { + fmt.Println(err) + } + fmt.Println(string(msg.Data())) + msg.Ack() + } +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } +} diff --git a/examples/jsv2/js-parallel-consume/main.go b/examples/jsv2/js-parallel-consume/main.go new file mode 100644 index 0000000..9f412f6 --- /dev/null +++ b/examples/jsv2/js-parallel-consume/main.go @@ -0,0 +1,92 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + defer cancel() + + nc, err := nats.Connect("nats://127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + + js, err := jetstream.New(nc) + if err != nil { + log.Fatal(err) + } + s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"FOO.*"}, + }) + if err != nil { + log.Fatal(err) + } + + cons, err := s.AddConsumer(ctx, jetstream.ConsumerConfig{ + Durable: "TestConsumerParallelConsume", + AckPolicy: jetstream.AckExplicitPolicy, + }) + if err != nil { + log.Fatal(err) + } + go endlessPublish(ctx, nc, js) + + for i := 0; i < 5; i++ { + cc, err := cons.Consume(func(consumeID int) jetstream.MessageHandler { + return func(msg jetstream.Msg) { + fmt.Printf("Received msg on consume %d\n", consumeID) + msg.Ack() + } + }(i), jetstream.ConsumeErrHandler(func(consumeCtx jetstream.ConsumeContext, err error) { + fmt.Println(err) + })) + if err != nil { + log.Fatal(err) + } + defer cc.Stop() + } + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + <-sig + +} + +func endlessPublish(ctx context.Context, nc *nats.Conn, js jetstream.JetStream) { + var i int + for { + time.Sleep(500 * time.Millisecond) + if nc.Status() != nats.CONNECTED { + continue + } + if _, err := js.Publish(ctx, "FOO.TEST1", []byte(fmt.Sprintf("msg %d", i))); err != nil { + fmt.Println("pub error: ", err) + } + i++ + } +} diff --git a/internal/parser/parse_test.go b/internal/parser/parse_test.go index 2dbf284..602e56e 100644 --- a/internal/parser/parse_test.go +++ b/internal/parser/parse_test.go @@ -18,6 +18,7 @@ import ( "math" "reflect" "strconv" + "strings" "testing" ) @@ -79,26 +80,25 @@ func TestParseNum(t *testing.T) { } } -// TODO: Add this test once CI uses go 1.18 -// func FuzzParseNum(f *testing.F) { -// testcases := []string{"191817", " ", "-123", "abc"} -// for _, tc := range testcases { -// f.Add(tc) -// } +func FuzzParseNum(f *testing.F) { + testcases := []string{"191817", " ", "-123", "abc"} + for _, tc := range testcases { + f.Add(tc) + } -// f.Fuzz(func(t *testing.T, given string) { -// given = strings.TrimLeft(given, "+") -// res := ParseNum(given) -// parsed, err := strconv.ParseUint(given, 10, 64) -// if err != nil && !errors.Is(err, strconv.ErrRange) { -// if res != 0 { -// t.Errorf("given: %s; expected: -1; got: %d; err: %v", given, res, err) -// } -// } else if err == nil && res != parsed { -// t.Errorf("given: %s; expected: %d; got: %d", given, parsed, res) -// } -// }) -// } + f.Fuzz(func(t *testing.T, given string) { + given = strings.TrimLeft(given, "+") + res := ParseNum(given) + parsed, err := strconv.ParseUint(given, 10, 64) + if err != nil && !errors.Is(err, strconv.ErrRange) { + if res != 0 { + t.Errorf("given: %s; expected: -1; got: %d; err: %v", given, res, err) + } + } else if err == nil && res != parsed { + t.Errorf("given: %s; expected: %d; got: %d", given, parsed, res) + } + }) +} func TestGetMetadataFields(t *testing.T) { tests := []struct { diff --git a/jsv2/README.md b/jetstream/README.md similarity index 99% rename from jsv2/README.md rename to jetstream/README.md index d189bd1..7f19b62 100644 --- a/jsv2/README.md +++ b/jetstream/README.md @@ -46,7 +46,7 @@ import ( "time" "github.com/nats-io/nats.go" - "github.com/nats-io/nats.go/jsv2/jetstream" + "github.com/nats-io/nats.go/jetstream" ) func main() { diff --git a/jsv2/jetstream/api.go b/jetstream/api.go similarity index 93% rename from jsv2/jetstream/api.go rename to jetstream/api.go index d3e3fcd..2d009d1 100644 --- a/jsv2/jetstream/api.go +++ b/jetstream/api.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -45,10 +45,11 @@ const ( apiAccountInfo = "INFO" // apiConsumerCreateT is used to create consumers. - apiConsumerCreateT = "CONSUMER.CREATE.%s" + apiConsumerCreateT = "CONSUMER.CREATE.%s.%s" - // apiDurableCreateT is used to create durable consumers. - apiDurableCreateT = "CONSUMER.DURABLE.CREATE.%s.%s" + // apiConsumerCreateT is used to create consumers. + // it accepts stream name, consumer name and filter subject + apiConsumerCreateWithFilterSubjectT = "CONSUMER.CREATE.%s.%s.%s" // apiConsumerInfoT is used to create consumers. apiConsumerInfoT = "CONSUMER.INFO.%s.%s" diff --git a/jetstream/consumer.go b/jetstream/consumer.go new file mode 100644 index 0000000..1ab8f87 --- /dev/null +++ b/jetstream/consumer.go @@ -0,0 +1,199 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jetstream + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "strings" + + "github.com/nats-io/nuid" +) + +type ( + + // Consumer contains methods for fetching/processing messages from a stream, as well as fetching consumer info + Consumer interface { + // Fetch is used to retrieve up to a provided number of messages from a stream. + // This method will always send a single request and wait until either all messages are retreived + // or request times out. + Fetch(int, ...FetchOpt) (MessageBatch, error) + // FetchBytes is used to retrieve up to a provided bytes from the stream. + // This method will always send a single request and wait until provided number of bytes is + // exceeded or request times out. + FetchBytes(int, ...FetchOpt) (MessageBatch, error) + // FetchNoWait is used to retrieve up to a provided number of messages from a stream. + // This method will always send a single request and immediately return up to a provided number of messages. + FetchNoWait(batch int) (MessageBatch, error) + // Consume can be used to continuously receive messages and handle them with the provided callback function + Consume(MessageHandler, ...PullConsumeOpt) (ConsumeContext, error) + // Messages returns [MessagesContext], allowing continuously iterating over messages on a stream. + Messages(...PullMessagesOpt) (MessagesContext, error) + // Next is used to retrieve the next message from the stream. + // This method will block until the message is retrieved or timeout is reached. + Next(...FetchOpt) (Msg, error) + + // Info returns Consumer details + Info(context.Context) (*ConsumerInfo, error) + // CachedInfo returns [*ConsumerInfo] cached on a consumer struct + CachedInfo() *ConsumerInfo + } +) + +// Info returns [ConsumerInfo] for a given consumer +func (p *pullConsumer) Info(ctx context.Context) (*ConsumerInfo, error) { + infoSubject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiConsumerInfoT, p.stream, p.name)) + var resp consumerInfoResponse + + if _, err := p.jetStream.apiRequestJSON(ctx, infoSubject, &resp); err != nil { + return nil, err + } + if resp.Error != nil { + if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { + return nil, ErrConsumerNotFound + } + return nil, resp.Error + } + + p.info = resp.ConsumerInfo + return resp.ConsumerInfo, nil +} + +// CachedInfo returns [ConsumerInfo] fetched when initializing/updating a consumer +// +// NOTE: The returned object might not be up to date with the most recent updates on the server +// For up-to-date information, use [Info] +func (p *pullConsumer) CachedInfo() *ConsumerInfo { + return p.info +} + +func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg ConsumerConfig) (Consumer, error) { + req := createConsumerRequest{ + Stream: stream, + Config: &cfg, + } + reqJSON, err := json.Marshal(req) + if err != nil { + return nil, err + } + + consumerName := cfg.Name + if consumerName == "" { + if cfg.Durable != "" { + consumerName = cfg.Durable + } else { + consumerName = generateConsName() + } + } + if err := validateConsumerName(consumerName); err != nil { + return nil, err + } + + var ccSubj string + if cfg.FilterSubject != "" { + ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateWithFilterSubjectT, stream, consumerName, cfg.FilterSubject)) + } else { + ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateT, stream, consumerName)) + } + var resp consumerInfoResponse + + if _, err := js.apiRequestJSON(ctx, ccSubj, &resp, reqJSON); err != nil { + return nil, err + } + if resp.Error != nil { + if resp.Error.ErrorCode == JSErrCodeStreamNotFound { + return nil, ErrStreamNotFound + } + return nil, resp.Error + } + + return &pullConsumer{ + jetStream: js, + stream: stream, + name: resp.Name, + durable: cfg.Durable != "", + info: resp.ConsumerInfo, + subscriptions: make(map[string]*pullSubscription), + }, nil +} + +func generateConsName() string { + name := nuid.Next() + sha := sha256.New() + sha.Write([]byte(name)) + b := sha.Sum(nil) + for i := 0; i < 8; i++ { + b[i] = rdigits[int(b[i]%base)] + } + return string(b[:8]) +} + +func getConsumer(ctx context.Context, js *jetStream, stream, name string) (Consumer, error) { + if err := validateConsumerName(name); err != nil { + return nil, err + } + infoSubject := apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerInfoT, stream, name)) + + var resp consumerInfoResponse + + if _, err := js.apiRequestJSON(ctx, infoSubject, &resp); err != nil { + return nil, err + } + if resp.Error != nil { + if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { + return nil, ErrConsumerNotFound + } + return nil, resp.Error + } + + cons := &pullConsumer{ + jetStream: js, + stream: stream, + name: name, + durable: resp.Config.Durable != "", + info: resp.ConsumerInfo, + subscriptions: make(map[string]*pullSubscription, 0), + } + + return cons, nil +} + +func deleteConsumer(ctx context.Context, js *jetStream, stream, consumer string) error { + if err := validateConsumerName(consumer); err != nil { + return err + } + deleteSubject := apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerDeleteT, stream, consumer)) + + var resp consumerDeleteResponse + + if _, err := js.apiRequestJSON(ctx, deleteSubject, &resp); err != nil { + return err + } + if resp.Error != nil { + if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { + return ErrConsumerNotFound + } + return resp.Error + } + return nil +} + +func validateConsumerName(dur string) error { + if strings.Contains(dur, ".") { + return fmt.Errorf("%w: '%s'", ErrInvalidConsumerName, dur) + } + return nil +} diff --git a/jsv2/jetstream/consumer_config.go b/jetstream/consumer_config.go similarity index 78% rename from jsv2/jetstream/consumer_config.go rename to jetstream/consumer_config.go index 7640c8d..3f11021 100644 --- a/jsv2/jetstream/consumer_config.go +++ b/jetstream/consumer_config.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -38,24 +38,27 @@ type ( // ConsumerConfig is the configuration of a JetStream consumer. ConsumerConfig struct { - Durable string `json:"durable_name,omitempty"` - Description string `json:"description,omitempty"` - DeliverPolicy DeliverPolicy `json:"deliver_policy"` - OptStartSeq uint64 `json:"opt_start_seq,omitempty"` - OptStartTime *time.Time `json:"opt_start_time,omitempty"` - AckPolicy AckPolicy `json:"ack_policy"` - AckWait time.Duration `json:"ack_wait,omitempty"` - MaxDeliver int `json:"max_deliver,omitempty"` - BackOff []time.Duration `json:"backoff,omitempty"` - FilterSubject string `json:"filter_subject,omitempty"` - ReplayPolicy ReplayPolicy `json:"replay_policy"` - RateLimit uint64 `json:"rate_limit_bps,omitempty"` // Bits per sec - SampleFrequency string `json:"sample_freq,omitempty"` - MaxWaiting int `json:"max_waiting,omitempty"` - MaxAckPending int `json:"max_ack_pending,omitempty"` - FlowControl bool `json:"flow_control,omitempty"` - Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` - HeadersOnly bool `json:"headers_only,omitempty"` + Name string `json:"name,omitempty"` + Durable string `json:"durable_name,omitempty"` + Description string `json:"description,omitempty"` + DeliverPolicy DeliverPolicy `json:"deliver_policy"` + OptStartSeq uint64 `json:"opt_start_seq,omitempty"` + OptStartTime *time.Time `json:"opt_start_time,omitempty"` + AckPolicy AckPolicy `json:"ack_policy"` + AckWait time.Duration `json:"ack_wait,omitempty"` + MaxDeliver int `json:"max_deliver,omitempty"` + BackOff []time.Duration `json:"backoff,omitempty"` + FilterSubjects []string `json:"filter_subjects,omitempty"` + FilterSubject string `json:"filter_subject,omitempty"` + ReplayPolicy ReplayPolicy `json:"replay_policy"` + RateLimit uint64 `json:"rate_limit_bps,omitempty"` // Bits per sec + SampleFrequency string `json:"sample_freq,omitempty"` + MaxWaiting int `json:"max_waiting,omitempty"` + MaxAckPending int `json:"max_ack_pending,omitempty"` + FlowControl bool `json:"flow_control,omitempty"` + Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` + HeadersOnly bool `json:"headers_only,omitempty"` + InactiveThreshold time.Duration `json:"inactive_threshold,omitempty"` // Pull based options. MaxRequestBatch int `json:"max_batch,omitempty"` @@ -65,15 +68,25 @@ type ( DeliverSubject string `json:"deliver_subject,omitempty"` DeliverGroup string `json:"deliver_group,omitempty"` - // Ephemeral inactivity threshold. - InactiveThreshold time.Duration `json:"inactive_threshold,omitempty"` - // Generally inherited by parent stream and other markers, now can be configured directly. Replicas int `json:"num_replicas"` // Force memory storage. MemoryStorage bool `json:"mem_storage,omitempty"` } + OrderedConsumerConfig struct { + FilterSubjects []string `json:"filter_subjects,omitempty"` + DeliverPolicy DeliverPolicy `json:"deliver_policy"` + OptStartSeq uint64 `json:"opt_start_seq,omitempty"` + OptStartTime *time.Time `json:"opt_start_time,omitempty"` + ReplayPolicy ReplayPolicy `json:"replay_policy"` + InactiveThreshold time.Duration `json:"inactive_threshold,omitempty"` + + // Maximum number of attempts for the consumer to be recreated + // Defaults to unlimited + MaxResetAttempts int + } + DeliverPolicy int // AckPolicy determines how the consumer should acknowledge delivered messages. diff --git a/jsv2/jetstream/consumer_test.go b/jetstream/consumer_test.go similarity index 95% rename from jsv2/jetstream/consumer_test.go rename to jetstream/consumer_test.go index 7fdf63e..7f53903 100644 --- a/jsv2/jetstream/consumer_test.go +++ b/jetstream/consumer_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -43,7 +43,7 @@ func TestConsumerInfo(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{ + c, err := s.AddConsumer(ctx, ConsumerConfig{ Durable: "cons", AckPolicy: AckExplicitPolicy, Description: "test consumer", @@ -65,7 +65,7 @@ func TestConsumerInfo(t *testing.T) { } // update consumer and see if info is updated - _, err = s.UpdateConsumer(ctx, ConsumerConfig{ + _, err = s.AddConsumer(ctx, ConsumerConfig{ Durable: "cons", AckPolicy: AckExplicitPolicy, Description: "updated consumer", @@ -139,7 +139,7 @@ func TestConsumerCachedInfo(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{ + c, err := s.AddConsumer(ctx, ConsumerConfig{ Durable: "cons", AckPolicy: AckExplicitPolicy, Description: "test consumer", @@ -158,7 +158,7 @@ func TestConsumerCachedInfo(t *testing.T) { } // update consumer and see if info is updated - _, err = s.UpdateConsumer(ctx, ConsumerConfig{ + _, err = s.AddConsumer(ctx, ConsumerConfig{ Durable: "cons", AckPolicy: AckExplicitPolicy, Description: "updated consumer", diff --git a/jsv2/jetstream/errors.go b/jetstream/errors.go similarity index 94% rename from jsv2/jetstream/errors.go rename to jetstream/errors.go index c2c655f..7d6c160 100644 --- a/jsv2/jetstream/errors.go +++ b/jetstream/errors.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -49,6 +49,7 @@ const ( JSErrCodeStreamNotFound ErrorCode = 10059 JSErrCodeStreamNameInUse ErrorCode = 10058 + JSErrCodeConsumerCreate ErrorCode = 10012 JSErrCodeConsumerNotFound ErrorCode = 10014 JSErrCodeConsumerNameExists ErrorCode = 10013 JSErrCodeConsumerAlreadyExists ErrorCode = 10105 @@ -82,6 +83,9 @@ var ( // ErrBadRequest is returned when invalid request is sent to JetStream API. ErrBadRequest JetStreamError = &jsError{apiErr: &APIError{ErrorCode: JSErrCodeBadRequest, Description: "bad request", Code: 400}} + // ErrConsumerCreate is returned when nats-server reports error when creating consumer (e.g. illegal update). + ErrConsumerCreate JetStreamError = &jsError{apiErr: &APIError{ErrorCode: JSErrCodeConsumerCreate, Description: "could not create consumer", Code: 500}} + // Client errors // ErrConsumerNotFound is an error returned when consumer with given name does not exist. @@ -155,6 +159,9 @@ var ( // ErrMsgIteratorClosed is returned when attempting to get message from a closed iterator ErrMsgIteratorClosed = &jsError{message: "messages iterator closed"} + + ErrOrderedConsumerReset = &jsError{message: "recreating ordered consumer"} + ErrOrderedSequenceMismatch = &jsError{message: "sequence mismatch"} ) // Error prints the JetStream API error code and description diff --git a/jsv2/jetstream/errors_test.go b/jetstream/errors_test.go similarity index 99% rename from jsv2/jetstream/errors_test.go rename to jetstream/errors_test.go index 2f150cc..411a6db 100644 --- a/jsv2/jetstream/errors_test.go +++ b/jetstream/errors_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/jsv2/jetstream/helper_test.go b/jetstream/helper_test.go similarity index 99% rename from jsv2/jetstream/helper_test.go rename to jetstream/helper_test.go index 5a9128b..090bfae 100644 --- a/jsv2/jetstream/helper_test.go +++ b/jetstream/helper_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/jsv2/jetstream/jetstream.go b/jetstream/jetstream.go similarity index 91% rename from jsv2/jetstream/jetstream.go rename to jetstream/jetstream.go index a8cfe8b..c6441de 100644 --- a/jsv2/jetstream/jetstream.go +++ b/jetstream/jetstream.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -21,6 +21,7 @@ import ( "strings" "github.com/nats-io/nats.go" + "github.com/nats-io/nuid" ) type ( @@ -29,7 +30,7 @@ type ( // Create, update and get operations return 'Stream' interface, // allowing operations on consumers // - // CreateConsumer, Consumer and DeleteConsumer are helper methods used to create/fetch/remove consumer without fetching stream (bypassing stream API) + // AddConsumer, Consumer and DeleteConsumer are helper methods used to create/fetch/remove consumer without fetching stream (bypassing stream API) // // Client returns a JetStremClient, used to publish messages on a stream or fetch messages by sequence number JetStream interface { @@ -76,12 +77,15 @@ type ( } StreamConsumerManager interface { - // CreateConsumer creates a consumer on a given stream with given config - // This operation is idempotent - if a consumer already exists, it will be a no-op (or error if configs do not match) + // AddConsumer creates a consumer on a given stream with given config. + // If consumer already exists, it will be updated (if possible). // Consumer interface is returned, serving as a hook to operate on a consumer (e.g. fetch messages) - CreateConsumer(context.Context, string, ConsumerConfig) (Consumer, error) - // UpdateConsumer updates an existing consumer - UpdateConsumer(context.Context, string, ConsumerConfig) (Consumer, error) + AddConsumer(context.Context, string, ConsumerConfig) (Consumer, error) + // OrderedConsumer returns an OrderedConsumer instance. + // OrderedConsumer allows fetching messages from a stream (just like standard consumer), + // for in order delivery of messages. Underlying consumer is re-created when necessary, + // without additional client code. + OrderedConsumer(context.Context, string, OrderedConsumerConfig) (Consumer, error) // Consumer returns a hook to an existing consumer, allowing processing of messages Consumer(context.Context, string, string) (Consumer, error) // DeleteConsumer removes a consumer with given name from a stream @@ -368,41 +372,32 @@ func (js *jetStream) DeleteStream(ctx context.Context, name string) error { return nil } -// CreateConsumer creates a consumer on a given stream with given config +// AddConsumer creates a consumer on a given stream with given config // This operation is idempotent - if a consumer already exists, it will be a no-op (or error if configs do not match) // Consumer interface is returned, serving as a hook to operate on a consumer (e.g. fetch messages) -func (js *jetStream) CreateConsumer(ctx context.Context, stream string, cfg ConsumerConfig) (Consumer, error) { +func (js *jetStream) AddConsumer(ctx context.Context, stream string, cfg ConsumerConfig) (Consumer, error) { if err := validateStreamName(stream); err != nil { return nil, err } - if cfg.Durable != "" { - c, err := js.Consumer(ctx, stream, cfg.Durable) - if err != nil && !errors.Is(err, ErrConsumerNotFound) { - return nil, err - } - if c != nil { - if err := compareConsumerConfig(&c.CachedInfo().Config, &cfg); err != nil { - return nil, fmt.Errorf("%w: %s", ErrConsumerNameAlreadyInUse, cfg.Durable) - } - return c, nil - } - } return upsertConsumer(ctx, js, stream, cfg) } -// UpdateConsumer updates an existing consumer -func (js *jetStream) UpdateConsumer(ctx context.Context, stream string, cfg ConsumerConfig) (Consumer, error) { +func (js *jetStream) OrderedConsumer(ctx context.Context, stream string, cfg OrderedConsumerConfig) (Consumer, error) { if err := validateStreamName(stream); err != nil { return nil, err } - if cfg.Durable == "" { - return nil, ErrConsumerNameRequired + oc := &orderedConsumer{ + jetStream: js, + cfg: &cfg, + stream: stream, + namePrefix: nuid.Next(), + doReset: make(chan struct{}, 1), } - _, err := js.Consumer(ctx, stream, cfg.Durable) - if err != nil { - return nil, err + if cfg.OptStartSeq != 0 { + oc.cursor.streamSeq = cfg.OptStartSeq - 1 } - return upsertConsumer(ctx, js, stream, cfg) + + return oc, nil } // Consumer returns a hook to an existing consumer, allowing processing of messages diff --git a/jsv2/jetstream/jetstream_test.go b/jetstream/jetstream_test.go similarity index 88% rename from jsv2/jetstream/jetstream_test.go rename to jetstream/jetstream_test.go index 41b7027..57b70c8 100644 --- a/jsv2/jetstream/jetstream_test.go +++ b/jetstream/jetstream_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -663,7 +663,7 @@ func TestStreamNames(t *testing.T) { } } -func TestJetStream_CreateConsumer(t *testing.T) { +func TestJetStream_AddConsumer(t *testing.T) { tests := []struct { name string stream string @@ -684,15 +684,15 @@ func TestJetStream_CreateConsumer(t *testing.T) { shouldCreate: true, }, { - name: "consumer already exists, idempotent operation", + name: "consumer already exists, update", stream: "foo", - consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckExplicitPolicy}, + consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckExplicitPolicy, Description: "test consumer"}, }, { - name: "consumer already exists, config mismatch", + name: "consumer already exists, illegal update", stream: "foo", - consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckExplicitPolicy, Description: "test"}, - withError: ErrConsumerNameAlreadyInUse, + consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckNonePolicy, Description: "test consumer"}, + withError: ErrConsumerCreate, }, { name: "stream does not exist", @@ -735,12 +735,12 @@ func TestJetStream_CreateConsumer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var sub *nats.Subscription - if test.consumerConfig.Durable != "" { - sub, err = nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.DURABLE.CREATE.foo.%s", test.consumerConfig.Durable)) + if test.consumerConfig.FilterSubject != "" { + sub, err = nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.CREATE.foo.*.%s", test.consumerConfig.FilterSubject)) } else { - sub, err = nc.SubscribeSync("$JS.API.CONSUMER.CREATE.foo") + sub, err = nc.SubscribeSync("$JS.API.CONSUMER.CREATE.foo.*") } - c, err := js.CreateConsumer(ctx, test.stream, test.consumerConfig) + c, err := js.AddConsumer(ctx, test.stream, test.consumerConfig) if test.withError != nil { if err == nil || !errors.Is(err, test.withError) { t.Fatalf("Expected error: %v; got: %v", test.withError, err) @@ -763,91 +763,6 @@ func TestJetStream_CreateConsumer(t *testing.T) { } } -func TestJetStream_UpdateConsumer(t *testing.T) { - tests := []struct { - name string - stream string - durable string - withError error - }{ - { - name: "update consumer", - stream: "foo", - durable: "dur", - }, - { - name: "consumer does not exist", - stream: "foo", - durable: "abc", - withError: ErrConsumerNotFound, - }, - { - name: "invalid durable name", - stream: "foo", - durable: "dur.123", - withError: ErrInvalidConsumerName, - }, - { - name: "stream does not exist", - stream: "abc", - durable: "dur", - withError: ErrStreamNotFound, - }, - { - name: "invalid stream name", - stream: "foo.1", - durable: "dur", - withError: ErrInvalidStreamName, - }, - } - - srv := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, srv) - nc, err := nats.Connect(srv.ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - js, err := New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c, err := js.UpdateConsumer(ctx, test.stream, ConsumerConfig{Durable: test.durable, AckPolicy: AckAllPolicy, Description: test.name}) - if test.withError != nil { - if err == nil || !errors.Is(err, test.withError) { - t.Fatalf("Expected error: %v; got: %v", test.withError, err) - } - return - } - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err = s.Consumer(ctx, c.CachedInfo().Name) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if c.CachedInfo().Config.Description != test.name { - t.Fatalf("Invalid consumer description after update; want: %s; got: %s", test.name, c.CachedInfo().Config.Description) - } - }) - } -} - func TestJetStream_Consumer(t *testing.T) { tests := []struct { name string @@ -905,7 +820,7 @@ func TestJetStream_Consumer(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) + _, err = s.AddConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -986,7 +901,7 @@ func TestJetStream_DeleteConsumer(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) + _, err = s.AddConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/jsv2/jetstream/message.go b/jetstream/message.go similarity index 99% rename from jsv2/jetstream/message.go rename to jetstream/message.go index 192127a..0274983 100644 --- a/jsv2/jetstream/message.go +++ b/jetstream/message.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/jsv2/jetstream/message_test.go b/jetstream/message_test.go similarity index 96% rename from jsv2/jetstream/message_test.go rename to jetstream/message_test.go index c9f874c..f4f87b8 100644 --- a/jsv2/jetstream/message_test.go +++ b/jetstream/message_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -43,7 +43,7 @@ func TestMessageDetails(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{ + c, err := s.AddConsumer(ctx, ConsumerConfig{ Durable: "cons", AckPolicy: AckExplicitPolicy, Description: "test consumer", @@ -59,9 +59,8 @@ func TestMessageDetails(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -164,7 +163,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{ + c, err := s.AddConsumer(ctx, ConsumerConfig{ Durable: "cons", AckPolicy: AckExplicitPolicy, Description: "test consumer", @@ -189,8 +188,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -227,8 +225,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -256,8 +253,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -294,8 +290,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -321,8 +316,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -359,8 +353,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -397,8 +390,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } @@ -435,8 +427,7 @@ func TestAckVariants(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { t.Fatalf("No messages available") } diff --git a/jsv2/jetstream/options.go b/jetstream/options.go similarity index 59% rename from jsv2/jetstream/options.go rename to jetstream/options.go index 8eeb0c4..de99c2f 100644 --- a/jsv2/jetstream/options.go +++ b/jetstream/options.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -18,6 +18,16 @@ import ( "time" ) +type pullOptFunc func(*consumeOpts) error + +func (fn pullOptFunc) configureConsume(opts *consumeOpts) error { + return fn(opts) +} + +func (fn pullOptFunc) configureMessages(opts *consumeOpts) error { + return fn(opts) +} + // WithClientTrace enables request/response API calls tracing // ClientTrace is used to provide handlers for each event func WithClientTrace(ct *ClientTrace) JetStreamOpt { @@ -78,101 +88,126 @@ func WithPurgeKeep(keep uint64) StreamPurgeOpt { } } -// WithConsumeMaxMessages limits the number of messages to be fetched from the stream in one request +// PullMaxMessages limits the number of messages to be fetched from the stream in one request // If not provided, a default of 100 messages will be used -func WithConsumeMaxMessages(maxMessages int) ConsumeOpts { - return func(cfg *consumeOpts) error { - if maxMessages <= 0 { - return fmt.Errorf("%w: maxMessages size must be at least 1", ErrInvalidOption) - } - cfg.MaxMessages = maxMessages - return nil +type PullMaxMessages int + +func (max PullMaxMessages) configureConsume(opts *consumeOpts) error { + if max <= 0 { + return fmt.Errorf("%w: maxMessages size must be at least 1", ErrInvalidOption) } + opts.MaxMessages = int(max) + return nil } -// WithConsumeExpiry sets timeout on a single batch request, waiting until at least one message is available -func WithConsumeExpiry(expires time.Duration) ConsumeOpts { - return func(cfg *consumeOpts) error { - if expires < 0 { - return fmt.Errorf("%w: expires value must be positive", ErrInvalidOption) - } - cfg.Expires = expires - return nil +func (max PullMaxMessages) configureMessages(opts *consumeOpts) error { + if max <= 0 { + return fmt.Errorf("%w: maxMessages size must be at least 1", ErrInvalidOption) } + opts.MaxMessages = int(max) + return nil } -// WithConsumeMaxBytes sets max_bytes limit on a fetch request -func WithConsumeMaxBytes(maxBytes int) ConsumeOpts { - return func(cfg *consumeOpts) error { - cfg.MaxBytes = maxBytes - return nil +// PullExpiry sets timeout on a single batch request, waiting until at least one message is available +type PullExpiry time.Duration + +func (exp PullExpiry) configureConsume(opts *consumeOpts) error { + if exp < 0 { + return fmt.Errorf("%w: expires value must be positive", ErrInvalidOption) } + opts.Expires = time.Duration(exp) + return nil } -// WithMessagesBatchSize limits the number of messages to be fetched from the stream in one request -// If not provided, a default of 100 messages will be used -func WithMessagesBatchSize(maxMessages int) ConsumerMessagesOpts { - return func(opts *consumeOpts) error { - if maxMessages <= 0 { - return fmt.Errorf("%w: batch size must be at least 1", ErrInvalidOption) - } - opts.MaxMessages = maxMessages - return nil +func (exp PullExpiry) configureMessages(opts *consumeOpts) error { + if exp < 0 { + return fmt.Errorf("%w: expires value must be positive", ErrInvalidOption) } + opts.Expires = time.Duration(exp) + return nil } -// WithMessagesHeartbeat sets the idle heartbeat duration for a pull subscription -// If a client does not receive a heartbeat meassage from a stream for more than the idle heartbeat setting, the subscription will be removed and error will be passed to the message handler -func WithMessagesHeartbeat(hb time.Duration) ConsumerMessagesOpts { - return func(opts *consumeOpts) error { - if hb <= 0 { - return fmt.Errorf("%w: idle_heartbeat value must be greater than 0", ErrInvalidOption) - } - opts.Heartbeat = hb - return nil +// PullMaxBytes sets max_bytes limit on a fetch request +type PullMaxBytes int + +func (max PullMaxBytes) configureConsume(opts *consumeOpts) error { + if max <= 0 { + return fmt.Errorf("%w: max bytes must be greater then 0", ErrInvalidOption) } + opts.MaxBytes = int(max) + return nil } -// WithMessagesMaxBytes sets max_bytes limit on a fetch request -func WithMessagesMaxBytes(maxBytes int) ConsumerMessagesOpts { - return func(opts *consumeOpts) error { - opts.MaxBytes = maxBytes - return nil +func (max PullMaxBytes) configureMessages(opts *consumeOpts) error { + if max <= 0 { + return fmt.Errorf("%w: max bytes must be greater then 0", ErrInvalidOption) } + opts.MaxBytes = int(max) + return nil } -// WithMessagesErrHandler sets custom error handler invoked when an error was encountered while consuming messages +// PullThresholdMessages sets the message count on which Consume will trigger +// new pull request to the server. Defaults to 50% of MaxMessages. +type PullThresholdMessages int + +func (t PullThresholdMessages) configureConsume(opts *consumeOpts) error { + opts.ThresholdMessages = int(t) + return nil +} + +// PullThresholdBytes sets the byte count on which Consume will trigger +// new pull request to the server. Defaults to 50% of MaxBytes (if set). +type PullThresholBytes int + +func (t PullThresholBytes) configureConsume(opts *consumeOpts) error { + opts.ThresholdBytes = int(t) + return nil +} + +// PullHeartbeat sets the idle heartbeat duration for a pull subscription +// If a client does not receive a heartbeat message from a stream for more +// than the idle heartbeat setting, the subscription will be removed +// and error will be passed to the message handler +type PullHeartbeat time.Duration + +func (hb PullHeartbeat) configureConsume(opts *consumeOpts) error { + hbTime := time.Duration(hb) + if hbTime < 1*time.Second || hbTime > 30*time.Second { + return fmt.Errorf("%w: idle_heartbeat value must be within 1s-30s range", ErrInvalidOption) + } + opts.Heartbeat = hbTime + return nil +} + +func (hb PullHeartbeat) configureMessages(opts *consumeOpts) error { + hbTime := time.Duration(hb) + if hbTime < 1*time.Second || hbTime > 30*time.Second { + return fmt.Errorf("%w: idle_heartbeat value must be within 1s-30s range", ErrInvalidOption) + } + opts.Heartbeat = hbTime + return nil +} + +// ConsumeErrHandler sets custom error handler invoked when an error was encountered while consuming messages // It will be invoked for both terminal (Consumer Deleted, invalid request body) and non-terminal (e.g. missing heartbeats) errors -func WithMessagesErrHandler(cb ConsumeErrHandler) ConsumerMessagesOpts { - return func(opts *consumeOpts) error { - opts.ErrHandler = cb +func ConsumeErrHandler(cb ConsumeErrHandlerFunc) PullConsumeOpt { + return pullOptFunc(func(cfg *consumeOpts) error { + cfg.ErrHandler = cb return nil - } + }) } -// WithConsumeHeartbeat sets the idle heartbeat duration for a pull subscription -// If a client does not receive a heartbeat meassage from a stream for more than the idle heartbeat setting, the subscription will be removed and error will be passed to the message handler -func WithConsumeHeartbeat(hb time.Duration) ConsumeOpts { - return func(req *consumeOpts) error { - if hb <= 0 { - return fmt.Errorf("%w: idle_heartbeat value must be greater than 0", ErrInvalidOption) - } - req.Heartbeat = hb - return nil - } -} - -// WithConsumeErrHandler sets custom error handler invoked when an error was encountered while consuming messages +// ConsumeErrHandler sets custom error handler invoked when an error was encountered while consuming messages // It will be invoked for both terminal (Consumer Deleted, invalid request body) and non-terminal (e.g. missing heartbeats) errors -func WithConsumeErrHandler(cb ConsumeErrHandler) ConsumeOpts { - return func(opts *consumeOpts) error { - opts.ErrHandler = cb +func WithMessagesErrOnMissingHeartbeat(hbErr bool) PullMessagesOpt { + return pullOptFunc(func(cfg *consumeOpts) error { + cfg.ReportMissingHeartbeats = hbErr return nil - } + }) } -// WithFetchTimeout sets custom timeout fir fetching predefined batch of messages -func WithFetchTimeout(timeout time.Duration) FetchOpt { +// FetchMaxWait sets custom timeout fir fetching predefined batch of messages +func FetchMaxWait(timeout time.Duration) FetchOpt { return func(req *pullRequest) error { if timeout <= 0 { return fmt.Errorf("%w: timeout value must be greater than 0", ErrInvalidOption) diff --git a/jetstream/ordered.go b/jetstream/ordered.go new file mode 100644 index 0000000..538f5de --- /dev/null +++ b/jetstream/ordered.go @@ -0,0 +1,436 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jetstream + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/nats-io/nats.go" +) + +type ( + orderedConsumer struct { + jetStream *jetStream + cfg *OrderedConsumerConfig + stream string + currentConsumer *pullConsumer + cursor cursor + namePrefix string + serial int + consumerType consumerType + doReset chan struct{} + resetInProgress uint32 + userErrHandler ConsumeErrHandlerFunc + runningFetch *fetchResult + } + + orderedSubscription struct { + consumer *orderedConsumer + opts []PullMessagesOpt + done chan struct{} + } + + cursor struct { + streamSeq uint64 + deliverSeq uint64 + } + + consumerType int +) + +const ( + consumerTypeNotSet consumerType = iota + consumerTypeConsume + consumerTypeFetch +) + +// Consume can be used to continuously receive messages and handle them with the provided callback function +func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (ConsumeContext, error) { + if c.consumerType == consumerTypeNotSet || c.consumerType == consumerTypeConsume && c.currentConsumer == nil { + c.consumerType = consumerTypeConsume + err := c.reset() + if err != nil { + return nil, err + } + } + if c.consumerType == consumerTypeFetch { + return nil, fmt.Errorf("ordered consumer initialized as fetch") + } + consumeOpts, err := parseConsumeOpts(opts...) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) + } + c.userErrHandler = consumeOpts.ErrHandler + opts = append(opts, ConsumeErrHandler(c.errHandler(c.serial))) + internalHandler := func(serial int) func(msg Msg) { + return func(msg Msg) { + // handler is a noop if message was delivered for a consumer with different serial + if serial != c.serial { + return + } + meta, err := msg.Metadata() + if err != nil { + c.errHandler(serial)(c.currentConsumer.subscriptions[""], err) + return + } + dseq := meta.Sequence.Consumer + if dseq != c.cursor.deliverSeq+1 { + c.errHandler(serial)(c.currentConsumer.subscriptions[""], ErrOrderedSequenceMismatch) + return + } + c.cursor.deliverSeq = dseq + c.cursor.streamSeq = meta.Sequence.Stream + handler(msg) + } + } + + _, err = c.currentConsumer.Consume(internalHandler(c.serial), opts...) + if err != nil { + return nil, err + } + + sub := &orderedSubscription{ + consumer: c, + done: make(chan struct{}, 1), + } + go func() { + for { + select { + case <-c.doReset: + if err := c.reset(); err != nil { + c.errHandler(c.serial)(c.currentConsumer.subscriptions[""], err) + } + // overwrite the previous err handler to use the new serial + opts[len(opts)-1] = ConsumeErrHandler(c.errHandler(c.serial)) + if _, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil { + c.errHandler(c.serial)(c.currentConsumer.subscriptions[""], err) + } + case <-sub.done: + return + } + } + }() + return sub, nil +} + +func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err error) { + return func(cc ConsumeContext, err error) { + if c.userErrHandler != nil { + c.userErrHandler(cc, err) + } + if errors.Is(err, ErrNoHeartbeat) || + errors.Is(err, ErrOrderedSequenceMismatch) || + errors.Is(err, ErrConsumerDeleted) { + // only reset if serial matches the currect consumer serial and there is no reset in progress + if serial == c.serial && atomic.LoadUint32(&c.resetInProgress) == 0 { + atomic.StoreUint32(&c.resetInProgress, 1) + c.doReset <- struct{}{} + } + + } + } +} + +// Messages returns [MessagesContext], allowing continuously iterating over messages on a stream. +func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error) { + if c.consumerType == consumerTypeNotSet { + c.consumerType = consumerTypeConsume + err := c.reset() + if err != nil { + return nil, err + } + } + if c.consumerType == consumerTypeFetch { + return nil, fmt.Errorf("ordered consumer initialized as fetch") + } + consumeOpts, err := parseMessagesOpts(opts...) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) + } + c.userErrHandler = consumeOpts.ErrHandler + opts = append(opts, WithMessagesErrOnMissingHeartbeat(true)) + _, err = c.currentConsumer.Messages(opts...) + if err != nil { + return nil, err + } + + sub := &orderedSubscription{ + consumer: c, + opts: opts, + done: make(chan struct{}, 1), + } + + return sub, nil +} + +func (s *orderedSubscription) Next() (Msg, error) { + next := func() (Msg, error) { + for { + currentConsumer := s.consumer.currentConsumer + msg, err := currentConsumer.subscriptions[""].Next() + if err != nil { + if err := s.consumer.reset(); err != nil { + return nil, err + } + _, err := s.consumer.currentConsumer.Messages(s.opts...) + if err != nil { + return nil, err + } + continue + } + meta, err := msg.Metadata() + if err != nil { + s.consumer.errHandler(s.consumer.serial)(currentConsumer.subscriptions[""], err) + continue + } + serial := serialNumberFromConsumer(meta.Consumer) + dseq := meta.Sequence.Consumer + if dseq != s.consumer.cursor.deliverSeq+1 { + s.consumer.errHandler(serial)(currentConsumer.subscriptions[""], ErrOrderedSequenceMismatch) + continue + } + s.consumer.cursor.deliverSeq = dseq + s.consumer.cursor.streamSeq = meta.Sequence.Stream + return msg, nil + } + } + return next() +} + +func (s *orderedSubscription) Stop() { + if s.consumer.currentConsumer == nil || s.consumer.currentConsumer.subscriptions[""] == nil { + return + } + s.consumer.currentConsumer.subscriptions[""].Stop() + close(s.done) +} + +// Fetch is used to retrieve up to a provided number of messages from a stream. +// This method will always send a single request and wait until either all messages are retreived +// or context reaches its deadline. +func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { + if c.consumerType == consumerTypeConsume { + return nil, fmt.Errorf("ordered consumer initialized as consume") + } + if c.runningFetch != nil { + if !c.runningFetch.done { + return nil, fmt.Errorf("cannot run concurrent ordered Fetch requests") + } + c.cursor.streamSeq = c.runningFetch.sseq + } + c.consumerType = consumerTypeFetch + err := c.reset() + if err != nil { + return nil, err + } + msgs, err := c.currentConsumer.Fetch(batch, opts...) + if err != nil { + return nil, err + } + c.runningFetch = msgs.(*fetchResult) + return msgs, nil +} + +// FetchBytes is used to retrieve up to a provided bytes from the stream. +// This method will always send a single request and wait until provided number of bytes is +// exceeded or request times out. +func (c *orderedConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) { + if c.consumerType == consumerTypeConsume { + return nil, fmt.Errorf("ordered consumer initialized as consume") + } + if c.runningFetch != nil { + if !c.runningFetch.done { + return nil, fmt.Errorf("cannot run concurrent ordered Fetch requests") + } + c.cursor.streamSeq = c.runningFetch.sseq + } + c.consumerType = consumerTypeFetch + err := c.reset() + if err != nil { + return nil, err + } + msgs, err := c.currentConsumer.FetchBytes(maxBytes, opts...) + if err != nil { + return nil, err + } + c.runningFetch = msgs.(*fetchResult) + return msgs, nil +} + +// FetchNoWait is used to retrieve up to a provided number of messages from a stream. +// This method will always send a single request and immediately return up to a provided number of messages +func (c *orderedConsumer) FetchNoWait(batch int) (MessageBatch, error) { + if c.consumerType == consumerTypeConsume { + return nil, fmt.Errorf("ordered consumer initialized as consume") + } + if c.runningFetch != nil && !c.runningFetch.done { + return nil, fmt.Errorf("cannot run concurrent ordered Fetch requests") + } + c.consumerType = consumerTypeFetch + err := c.reset() + if err != nil { + return nil, err + } + return c.currentConsumer.FetchNoWait(batch) +} + +func (c *orderedConsumer) Next(opts ...FetchOpt) (Msg, error) { + res, err := c.Fetch(1, opts...) + if err != nil { + return nil, err + } + msg := <-res.Messages() + if msg != nil { + return msg, nil + } + return nil, res.Error() +} + +func serialNumberFromConsumer(name string) int { + if len(name) == 0 { + return 0 + } + serial, err := strconv.Atoi(name[len(name)-1:]) + if err != nil { + return 0 + } + return serial +} + +func (c *orderedConsumer) reset() error { + defer atomic.StoreUint32(&c.resetInProgress, 0) + if c.currentConsumer != nil { + // c.currentConsumer.subscription.Stop() + var err error + for i := 0; ; i++ { + if c.cfg.MaxResetAttempts > 0 && i == c.cfg.MaxResetAttempts { + return fmt.Errorf("%w: maximum number of delete attempts reached: %s", ErrOrderedConsumerReset, err) + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + err = c.jetStream.DeleteConsumer(ctx, c.stream, c.currentConsumer.CachedInfo().Name) + if err != nil { + if errors.Is(err, ErrConsumerNotFound) { + cancel() + break + } + if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { + cancel() + continue + } + cancel() + return err + } + cancel() + break + } + } + seq := c.cursor.streamSeq + 1 + c.cursor.deliverSeq = 0 + consumerConfig := c.getConsumerConfigForSeq(seq) + + var err error + var cons Consumer + for i := 0; ; i++ { + if c.cfg.MaxResetAttempts > 0 && i == c.cfg.MaxResetAttempts { + return fmt.Errorf("%w: maximum number of create consumer attempts reached: %s", ErrOrderedConsumerReset, err) + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + cons, err = c.jetStream.AddConsumer(ctx, c.stream, *consumerConfig) + if err != nil { + if errors.Is(err, ErrConsumerNotFound) { + cancel() + break + } + if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { + cancel() + continue + } + cancel() + return err + } + cancel() + break + } + c.currentConsumer = cons.(*pullConsumer) + return nil +} + +func (c *orderedConsumer) getConsumerConfigForSeq(seq uint64) *ConsumerConfig { + c.serial++ + name := fmt.Sprintf("%s_%d", c.namePrefix, c.serial) + cfg := &ConsumerConfig{ + Name: name, + DeliverPolicy: DeliverByStartSequencePolicy, + OptStartSeq: seq, + AckPolicy: AckNonePolicy, + InactiveThreshold: 5 * time.Minute, + Replicas: 1, + FilterSubjects: c.cfg.FilterSubjects, + } + + if seq != c.cfg.OptStartSeq+1 { + return cfg + } + + // initial request, some options may be modified at that point + cfg.DeliverPolicy = c.cfg.DeliverPolicy + if c.cfg.DeliverPolicy == DeliverLastPerSubjectPolicy || + c.cfg.DeliverPolicy == DeliverLastPolicy || + c.cfg.DeliverPolicy == DeliverNewPolicy || + c.cfg.DeliverPolicy == DeliverAllPolicy { + + cfg.OptStartSeq = 0 + } + + if cfg.DeliverPolicy == DeliverLastPerSubjectPolicy && len(c.cfg.FilterSubjects) == 0 { + cfg.FilterSubjects = []string{">"} + } + if c.cfg.OptStartTime != nil { + cfg.OptStartSeq = 0 + cfg.DeliverPolicy = DeliverByStartTimePolicy + cfg.OptStartTime = c.cfg.OptStartTime + } + if c.cfg.InactiveThreshold != 0 { + cfg.InactiveThreshold = c.cfg.InactiveThreshold + } + + return cfg +} + +func (c *orderedConsumer) Info(ctx context.Context) (*ConsumerInfo, error) { + infoSubject := apiSubj(c.jetStream.apiPrefix, fmt.Sprintf(apiConsumerInfoT, c.stream, c.currentConsumer.name)) + var resp consumerInfoResponse + + if _, err := c.jetStream.apiRequestJSON(ctx, infoSubject, &resp); err != nil { + return nil, err + } + if resp.Error != nil { + if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { + return nil, ErrConsumerNotFound + } + return nil, resp.Error + } + + c.currentConsumer.info = resp.ConsumerInfo + return resp.ConsumerInfo, nil +} + +func (c *orderedConsumer) CachedInfo() *ConsumerInfo { + return c.currentConsumer.info +} diff --git a/jetstream/ordered_test.go b/jetstream/ordered_test.go new file mode 100644 index 0000000..04c2860 --- /dev/null +++ b/jetstream/ordered_test.go @@ -0,0 +1,191 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jetstream + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/nats-io/nats.go" +) + +func TestOrderedConsumerConsume(t *testing.T) { + testSubject := "FOO.123" + testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} + publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + for _, msg := range testMsgs { + if err := nc.Publish(testSubject, []byte(msg)); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + } + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]Msg, 0) + wg := &sync.WaitGroup{} + wg.Add(len(testMsgs)) + l, err := c.Consume(func(msg Msg) { + msgs = append(msgs, msg) + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc) + wg.Wait() + + name := c.CachedInfo().Name + if err := s.DeleteConsumer(ctx, name); err != nil { + t.Fatal(err) + } + wg.Add(len(testMsgs)) + publishTestMsgs(t, nc) + wg.Wait() + + l.Stop() + time.Sleep(10 * time.Millisecond) + publishTestMsgs(t, nc) + wg.Add(len(testMsgs)) + l, err = c.Consume(func(msg Msg) { + msgs = append(msgs, msg) + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer l.Stop() + wg.Wait() + if len(msgs) != 3*len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) + } + for i, msg := range msgs { + if string(msg.Data()) != testMsgs[i%5] { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + } + } +} + +func TestOrderedConsumerMessages(t *testing.T) { + testSubject := "FOO.123" + testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} + publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + for _, msg := range testMsgs { + if err := nc.Publish(testSubject, []byte(msg)); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + } + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]Msg, 0) + it, err := c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, nc) + for i := 0; i < 5; i++ { + msg, err := it.Next() + if err != nil { + t.Fatal(err) + } + msg.Ack() + msgs = append(msgs, msg) + } + + name := c.CachedInfo().Name + if err := s.DeleteConsumer(ctx, name); err != nil { + t.Fatal(err) + } + publishTestMsgs(t, nc) + for i := 0; i < 5; i++ { + msg, err := it.Next() + if err != nil { + t.Fatal(err) + } + msg.Ack() + msgs = append(msgs, msg) + } + + it.Stop() + time.Sleep(10 * time.Millisecond) + it, err = c.Messages() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, nc) + for i := 0; i < 5; i++ { + msg, err := it.Next() + if err != nil { + t.Fatal(err) + } + msg.Ack() + msgs = append(msgs, msg) + } + + if len(msgs) != 3*len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) + } + for i, msg := range msgs { + if string(msg.Data()) != testMsgs[i%5] { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + } + } +} diff --git a/jsv2/jetstream/publish.go b/jetstream/publish.go similarity index 99% rename from jsv2/jetstream/publish.go rename to jetstream/publish.go index 62bd1e0..716852b 100644 --- a/jsv2/jetstream/publish.go +++ b/jetstream/publish.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/jsv2/jetstream/publish_test.go b/jetstream/publish_test.go similarity index 99% rename from jsv2/jetstream/publish_test.go rename to jetstream/publish_test.go index 7014783..45a7319 100644 --- a/jsv2/jetstream/publish_test.go +++ b/jetstream/publish_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/jetstream/pull.go b/jetstream/pull.go new file mode 100644 index 0000000..27891fa --- /dev/null +++ b/jetstream/pull.go @@ -0,0 +1,822 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jetstream + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math" + "sync" + "sync/atomic" + "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nuid" +) + +type ( + // MessagesContext supports iterating over a messages on a stream. + MessagesContext interface { + // Next retreives nest message on a stream. It will block until the next message is available. + Next() (Msg, error) + // Stop closes the iterator and cancels subscription. + Stop() + } + + ConsumeContext interface { + Stop() + } + + // MessageHandler is a handler function used as callback in [Consume] + MessageHandler func(msg Msg) + + // PullConsumeOpt represent additional options used in [Consume] for pull consumers + PullConsumeOpt interface { + configureConsume(*consumeOpts) error + } + + // PullMessagesOpt represent additional options used in [Messages] for pull consumers + PullMessagesOpt interface { + configureMessages(*consumeOpts) error + } + + pullConsumer struct { + sync.Mutex + jetStream *jetStream + stream string + durable bool + name string + info *ConsumerInfo + subscriptions map[string]*pullSubscription + } + + pullRequest struct { + Expires time.Duration `json:"expires,omitempty"` + Batch int `json:"batch,omitempty"` + MaxBytes int `json:"max_bytes,omitempty"` + NoWait bool `json:"no_wait,omitempty"` + Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` + } + + consumeOpts struct { + Expires time.Duration + MaxMessages int + MaxBytes int + Heartbeat time.Duration + ErrHandler ConsumeErrHandlerFunc + ReportMissingHeartbeats bool + ThresholdMessages int + ThresholdBytes int + } + + ConsumeErrHandlerFunc func(consumeCtx ConsumeContext, err error) + + pullSubscription struct { + sync.Mutex + id string + consumer *pullConsumer + subscription *nats.Subscription + msgs chan *nats.Msg + errs chan error + pending pendingMsgs + hbMonitor *hbMonitor + fetchInProgress uint32 + closed uint32 + done chan struct{} + connected chan struct{} + disconnected chan struct{} + fetchNext chan *pullRequest + consumeOpts *consumeOpts + } + + pendingMsgs struct { + msgCount int + byteCount int + } + + MessageBatch interface { + Messages() <-chan Msg + Error() error + } + + fetchResult struct { + msgs chan Msg + err error + done bool + sseq uint64 + } + + FetchOpt func(*pullRequest) error + + hbMonitor struct { + timer *time.Timer + sync.Mutex + } +) + +const ( + DefaultMaxMessages = 500 + DefaultExpires = 30 * time.Second + DefaultHeartbeat = 5 * time.Second + unset = -1 +) + +// Consume returns a ConsumeContext, allowing for processing incoming messages from a stream in a given callback function. +// +// Available options: +// [ConsumeMaxMessages] - sets maximum number of messages stored in a buffer, default is set to 100 +// [ConsumeMaxBytes] - sets maximum number of bytes stored in a buffer +// [ConsumeExpiry] - sets a timeout for individual batch request, default is set to 30 seconds +// [ConsumeHeartbeat] - sets an idle heartbeat setting for a pull request, default is set to 5s +// [ConsumeErrHandler] - sets custom consume error callback handler +// [ConsumeThresholdMessages] - sets the byte count on which Consume will trigger new pull request to the server +// [ConsumeThresholdBytes] - sets the message count on which Consume will trigger new pull request to the server +func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (ConsumeContext, error) { + if handler == nil { + return nil, ErrHandlerRequired + } + consumeOpts, err := parseConsumeOpts(opts...) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) + } + p.Lock() + + subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) + + // for single consume, use empty string as id + // this is useful for ordered consumer, where only a single subscription is valid + var consumeID string + if len(p.subscriptions) > 0 { + consumeID = nuid.Next() + } + sub := &pullSubscription{ + id: consumeID, + consumer: p, + errs: make(chan error, 1), + done: make(chan struct{}, 1), + fetchNext: make(chan *pullRequest, 1), + connected: make(chan struct{}), + disconnected: make(chan struct{}), + consumeOpts: consumeOpts, + } + p.jetStream.conn.RegisterStatusChangeListener(nats.CONNECTED, sub.connected) + p.jetStream.conn.RegisterStatusChangeListener(nats.DISCONNECTED, sub.disconnected) + p.jetStream.conn.RegisterStatusChangeListener(nats.RECONNECTING, sub.disconnected) + + sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat) + + p.subscriptions[sub.id] = sub + p.Unlock() + + internalHandler := func(msg *nats.Msg) { + if sub.hbMonitor != nil { + sub.hbMonitor.Reset(2 * consumeOpts.Heartbeat) + } + userMsg, msgErr := checkMsg(msg) + if !userMsg && msgErr == nil { + return + } + defer func() { + if sub.pending.msgCount < consumeOpts.ThresholdMessages || + (sub.pending.byteCount < consumeOpts.ThresholdBytes && sub.consumeOpts.MaxBytes != 0) && + atomic.LoadUint32(&sub.fetchInProgress) == 1 { + + sub.fetchNext <- &pullRequest{ + Expires: sub.consumeOpts.Expires, + Batch: sub.consumeOpts.MaxMessages - sub.pending.msgCount, + MaxBytes: sub.consumeOpts.MaxBytes - sub.pending.byteCount, + Heartbeat: sub.consumeOpts.Heartbeat, + } + sub.resetPendingMsgs() + } + }() + if !userMsg { + // heartbeat message + if msgErr == nil { + return + } + if err := sub.handleStatusMsg(msg, msgErr); err != nil { + if atomic.LoadUint32(&sub.closed) == 1 { + return + } + if sub.consumeOpts.ErrHandler != nil { + sub.consumeOpts.ErrHandler(sub, err) + } + sub.Stop() + } + return + } + handler(p.jetStream.toJSMsg(msg)) + sub.decrementPendingMsgs(msg) + } + inbox := nats.NewInbox() + sub.subscription, err = p.jetStream.conn.Subscribe(inbox, internalHandler) + if err != nil { + return nil, err + } + + // initial pull + sub.resetPendingMsgs() + if err := sub.pull(&pullRequest{ + Expires: consumeOpts.Expires, + Batch: consumeOpts.MaxMessages, + MaxBytes: consumeOpts.MaxBytes, + Heartbeat: consumeOpts.Heartbeat, + }, subject); err != nil { + sub.errs <- err + } + + go func() { + isConnected := true + for { + if atomic.LoadUint32(&sub.closed) == 1 { + return + } + select { + case <-sub.disconnected: + if sub.hbMonitor != nil { + sub.hbMonitor.Stop() + } + isConnected = false + case <-sub.connected: + if !isConnected { + // try fetching consumer info several times to make sure consumer is available after reconnect + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _, err := p.Info(ctx) + cancel() + if err == nil { + break + } + if err != nil { + if i == 4 { + sub.cleanupSubscriptionAndRestoreConnHandler() + if sub.consumeOpts.ErrHandler != nil { + sub.consumeOpts.ErrHandler(sub, err) + } + return + } + } + time.Sleep(5 * time.Second) + } + + sub.fetchNext <- &pullRequest{ + Expires: sub.consumeOpts.Expires, + Batch: sub.consumeOpts.MaxMessages, + MaxBytes: sub.consumeOpts.MaxBytes, + Heartbeat: sub.consumeOpts.Heartbeat, + } + sub.resetPendingMsgs() + isConnected = true + } + case err := <-sub.errs: + if sub.consumeOpts.ErrHandler != nil { + sub.consumeOpts.ErrHandler(sub, err) + } + if errors.Is(err, ErrNoHeartbeat) { + sub.fetchNext <- &pullRequest{ + Expires: sub.consumeOpts.Expires, + Batch: sub.consumeOpts.MaxMessages, + MaxBytes: sub.consumeOpts.MaxBytes, + Heartbeat: sub.consumeOpts.Heartbeat, + } + sub.resetPendingMsgs() + } + } + } + }() + + go sub.pullMessages(subject) + + return sub, nil +} + +func (s *pullSubscription) resetPendingMsgs() { + s.Lock() + defer s.Unlock() + s.pending.msgCount = s.consumeOpts.MaxMessages + s.pending.byteCount = s.consumeOpts.MaxBytes +} + +func (s *pullSubscription) decrementPendingMsgs(msg *nats.Msg) { + s.Lock() + defer s.Unlock() + s.pending.msgCount-- + if s.consumeOpts.MaxBytes != 0 { + s.pending.byteCount -= msgSize(msg) + } +} + +// Messages returns MessagesContext, allowing continuously iterating over messages on a stream. +// +// Available options: +// [ConsumeMaxMessages] - sets maximum number of messages stored in a buffer, default is set to 100 +// [ConsumeMaxBytes] - sets maximum number of bytes stored in a buffer +// [ConsumeExpiry] - sets a timeout for individual batch request, default is set to 30 seconds +// [ConsumeHeartbeat] - sets an idle heartbeat setting for a pull request, default is set to 5s +// [ConsumeErrHandler] - sets custom consume error callback handler +// [ConsumeThresholdMessages] - sets the byte count on which Consume will trigger new pull request to the server +// [ConsumeThresholdBytes] - sets the message count on which Consume will trigger new pull request to the server +func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error) { + consumeOpts, err := parseMessagesOpts(opts...) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) + } + + p.Lock() + subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) + + msgs := make(chan *nats.Msg, consumeOpts.MaxMessages) + + // for single consume, use empty string as id + // this is useful for ordered consumer, where only a single subscription is valid + var consumeID string + if len(p.subscriptions) > 0 { + consumeID = nuid.Next() + } + sub := &pullSubscription{ + id: consumeID, + consumer: p, + done: make(chan struct{}, 1), + msgs: msgs, + errs: make(chan error, 1), + fetchNext: make(chan *pullRequest, 1), + connected: make(chan struct{}), + disconnected: make(chan struct{}), + consumeOpts: consumeOpts, + } + p.jetStream.conn.RegisterStatusChangeListener(nats.CONNECTED, sub.connected) + p.jetStream.conn.RegisterStatusChangeListener(nats.DISCONNECTED, sub.disconnected) + p.jetStream.conn.RegisterStatusChangeListener(nats.RECONNECTING, sub.disconnected) + inbox := nats.NewInbox() + sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs) + if err != nil { + p.Unlock() + return nil, err + } + + go func() { + <-sub.done + sub.cleanupSubscriptionAndRestoreConnHandler() + }() + p.subscriptions[sub.id] = sub + p.Unlock() + + go sub.pullMessages(subject) + + return sub, nil +} + +func (s *pullSubscription) Next() (Msg, error) { + s.Lock() + defer s.Unlock() + if atomic.LoadUint32(&s.closed) == 1 { + return nil, ErrMsgIteratorClosed + } + hbMonitor := s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat) + defer func() { + if hbMonitor != nil { + hbMonitor.Stop() + } + }() + + isConnected := true + for { + if s.pending.msgCount < s.consumeOpts.ThresholdMessages || + (s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0) && + atomic.LoadUint32(&s.fetchInProgress) == 1 { + + s.fetchNext <- &pullRequest{ + Expires: s.consumeOpts.Expires, + Batch: s.consumeOpts.MaxMessages - s.pending.msgCount, + MaxBytes: s.consumeOpts.MaxBytes - s.pending.byteCount, + Heartbeat: s.consumeOpts.Heartbeat, + } + s.pending.msgCount = s.consumeOpts.MaxMessages + if s.consumeOpts.MaxBytes > 0 { + s.pending.byteCount = s.consumeOpts.MaxBytes + } + } + select { + case msg := <-s.msgs: + if hbMonitor != nil { + hbMonitor.Reset(2 * s.consumeOpts.Heartbeat) + } + userMsg, msgErr := checkMsg(msg) + if !userMsg { + // heartbeat message + if msgErr == nil { + continue + } + if err := s.handleStatusMsg(msg, msgErr); err != nil { + s.Stop() + return nil, err + } + continue + } + s.pending.msgCount-- + if s.consumeOpts.MaxBytes > 0 { + s.pending.byteCount -= msgSize(msg) + } + return s.consumer.jetStream.toJSMsg(msg), nil + case <-s.disconnected: + if hbMonitor != nil { + hbMonitor.Stop() + } + isConnected = false + case <-s.connected: + if !isConnected { + // try fetching consumer info several times to make sure consumer is available after reconnect + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _, err := s.consumer.Info(ctx) + cancel() + if err == nil { + break + } + if err != nil { + if i == 4 { + s.Stop() + return nil, err + } + } + time.Sleep(5 * time.Second) + } + s.pending.msgCount = 0 + s.pending.byteCount = 0 + hbMonitor = s.scheduleHeartbeatCheck(s.consumeOpts.Heartbeat) + } + case err := <-s.errs: + if errors.Is(err, ErrNoHeartbeat) { + s.pending.msgCount = 0 + s.pending.byteCount = 0 + if s.consumeOpts.ReportMissingHeartbeats { + return nil, err + } + } + } + } +} + +func (s *pullSubscription) handleStatusMsg(msg *nats.Msg, msgErr error) error { + if !errors.Is(msgErr, nats.ErrTimeout) && !errors.Is(msgErr, ErrMaxBytesExceeded) { + if s.consumeOpts.ErrHandler != nil { + s.consumeOpts.ErrHandler(s, msgErr) + } + if errors.Is(msgErr, ErrConsumerDeleted) || errors.Is(msgErr, ErrBadRequest) { + return msgErr + } + if errors.Is(msgErr, ErrConsumerLeadershipChanged) { + s.pending.msgCount = 0 + s.pending.byteCount = 0 + } + return nil + } + msgsLeft, bytesLeft, err := parsePending(msg) + if err != nil { + if s.consumeOpts.ErrHandler != nil { + s.consumeOpts.ErrHandler(s, err) + } + } + s.pending.msgCount -= msgsLeft + if s.pending.msgCount < 0 { + s.pending.msgCount = 0 + } + if s.consumeOpts.MaxBytes > 0 { + s.pending.byteCount -= bytesLeft + if s.pending.byteCount < 0 { + s.pending.byteCount = 0 + } + } + return nil +} + +func (hb *hbMonitor) Stop() { + hb.Mutex.Lock() + hb.timer.Stop() + hb.Mutex.Unlock() +} + +func (hb *hbMonitor) Reset(dur time.Duration) { + hb.Mutex.Lock() + hb.timer.Reset(dur) + hb.Mutex.Unlock() +} + +func (s *pullSubscription) Stop() { + if atomic.LoadUint32(&s.closed) == 1 { + return + } + close(s.done) + atomic.StoreUint32(&s.closed, 1) +} + +// Fetch sends a single request to retrieve given number of messages. +// It will wait up to provided expiry time if not all messages are available. +func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { + req := &pullRequest{ + Batch: batch, + Expires: DefaultExpires, + } + for _, opt := range opts { + if err := opt(req); err != nil { + return nil, err + } + } + // for longer pulls, set heartbeat value + if req.Expires >= 10*time.Second { + req.Heartbeat = 5 * time.Second + } + + return p.fetch(req) + +} + +// FetchBytes is used to retrieve up to a provided bytes from the stream. +// This method will always send a single request and wait until provided number of bytes is +// exceeded or request times out. +func (p *pullConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBatch, error) { + req := &pullRequest{ + Batch: 1000000, + MaxBytes: maxBytes, + Expires: DefaultExpires, + } + for _, opt := range opts { + if err := opt(req); err != nil { + return nil, err + } + } + // for longer pulls, set heartbeat value + if req.Expires >= 10*time.Second { + req.Heartbeat = 5 * time.Second + } + + return p.fetch(req) +} + +// Fetch sends a single request to retrieve given number of messages. +// If there are any messages available at the time of sending request, +// FetchNoWait will return immediately. +func (p *pullConsumer) FetchNoWait(batch int) (MessageBatch, error) { + req := &pullRequest{ + Batch: batch, + NoWait: true, + } + + return p.fetch(req) +} + +func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { + res := &fetchResult{ + msgs: make(chan Msg, req.Batch), + } + msgs := make(chan *nats.Msg, 2*req.Batch) + subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) + + sub := &pullSubscription{ + consumer: p, + done: make(chan struct{}, 1), + msgs: msgs, + errs: make(chan error, 1), + } + inbox := nats.NewInbox() + var err error + sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs) + if err != nil { + return nil, err + } + if err := sub.pull(req, subject); err != nil { + return nil, err + } + + var receivedMsgs, receivedBytes int + hbTimer := sub.scheduleHeartbeatCheck(req.Heartbeat) + go func(res *fetchResult) { + defer sub.subscription.Unsubscribe() + defer close(res.msgs) + for { + if receivedMsgs == req.Batch || (req.MaxBytes != 0 && receivedBytes == req.MaxBytes) { + res.done = true + return + } + select { + case msg := <-msgs: + if hbTimer != nil { + hbTimer.Reset(2 * req.Heartbeat) + } + userMsg, err := checkMsg(msg) + if err != nil { + if !errors.Is(err, nats.ErrTimeout) && !errors.Is(err, ErrNoMessages) && !errors.Is(err, ErrMaxBytesExceeded) { + res.err = err + } + res.done = true + return + } + if !userMsg { + continue + } + res.msgs <- p.jetStream.toJSMsg(msg) + meta, err := msg.Metadata() + if err != nil { + res.err = fmt.Errorf("parsing message metadata: %s", err) + } + res.sseq = meta.Sequence.Stream + receivedMsgs++ + if req.MaxBytes != 0 { + receivedBytes += msgSize(msg) + } + case <-time.After(req.Expires + 1*time.Second): + res.err = fmt.Errorf("fetch timed out") + res.done = true + return + } + } + }(res) + return res, nil +} + +func (fr *fetchResult) Messages() <-chan Msg { + return fr.msgs +} + +func (fr *fetchResult) Error() error { + return fr.err +} + +func (p *pullConsumer) Next(opts ...FetchOpt) (Msg, error) { + res, err := p.Fetch(1, opts...) + if err != nil { + return nil, err + } + msg := <-res.Messages() + if msg != nil { + return msg, nil + } + return nil, res.Error() +} + +func (s *pullSubscription) pullMessages(subject string) { + for { + select { + case req := <-s.fetchNext: + atomic.StoreUint32(&s.fetchInProgress, 1) + + if err := s.pull(req, subject); err != nil { + if errors.Is(err, ErrMsgIteratorClosed) { + s.cleanupSubscriptionAndRestoreConnHandler() + return + } + s.errs <- err + } + atomic.StoreUint32(&s.fetchInProgress, 0) + case <-s.done: + s.cleanupSubscriptionAndRestoreConnHandler() + return + } + } +} + +func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor { + if dur == 0 { + return nil + } + return &hbMonitor{ + timer: time.AfterFunc(2*dur, func() { + s.errs <- ErrNoHeartbeat + }), + } +} + +func (s *pullSubscription) cleanupSubscriptionAndRestoreConnHandler() { + s.consumer.Lock() + defer s.consumer.Unlock() + if s.subscription == nil { + return + } + if s.hbMonitor != nil { + s.hbMonitor.Stop() + } + s.subscription.Unsubscribe() + close(s.connected) + close(s.disconnected) + s.subscription = nil + delete(s.consumer.subscriptions, s.id) +} + +func msgSize(msg *nats.Msg) int { + if msg == nil { + return 0 + } + size := len(msg.Subject) + len(msg.Reply) + len(msg.Data) + return size +} + +// pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription]. +// Messages will be fetched up to given batch_size or until there are no more messages or timeout is returned +func (s *pullSubscription) pull(req *pullRequest, subject string) error { + s.consumer.Lock() + defer s.consumer.Unlock() + if atomic.LoadUint32(&s.closed) == 1 { + return ErrMsgIteratorClosed + } + if req.Batch < 1 { + return fmt.Errorf("%w: batch size must be at least 1", nats.ErrInvalidArg) + } + reqJSON, err := json.Marshal(req) + if err != nil { + return err + } + + reply := s.subscription.Subject + if err := s.consumer.jetStream.conn.PublishRequest(subject, reply, reqJSON); err != nil { + return err + } + return nil +} + +func parseConsumeOpts(opts ...PullConsumeOpt) (*consumeOpts, error) { + consumeOpts := &consumeOpts{ + MaxMessages: unset, + MaxBytes: unset, + Expires: DefaultExpires, + Heartbeat: unset, + ReportMissingHeartbeats: true, + } + for _, opt := range opts { + if err := opt.configureConsume(consumeOpts); err != nil { + return nil, err + } + } + if err := consumeOpts.setDefaults(); err != nil { + return nil, err + } + return consumeOpts, nil +} + +func parseMessagesOpts(opts ...PullMessagesOpt) (*consumeOpts, error) { + consumeOpts := &consumeOpts{ + MaxMessages: unset, + MaxBytes: unset, + Expires: DefaultExpires, + Heartbeat: unset, + ReportMissingHeartbeats: true, + } + for _, opt := range opts { + if err := opt.configureMessages(consumeOpts); err != nil { + return nil, err + } + } + if err := consumeOpts.setDefaults(); err != nil { + return nil, err + } + return consumeOpts, nil +} + +func (consumeOpts *consumeOpts) setDefaults() error { + if consumeOpts.MaxBytes != unset && consumeOpts.MaxMessages != unset { + return fmt.Errorf("only one of MaxMessages and MaxBytes can be specified") + } + if consumeOpts.MaxBytes != unset { + // when max_bytes is used, set batch size to a very large number + consumeOpts.MaxMessages = 1000000 + } else if consumeOpts.MaxMessages != unset { + consumeOpts.MaxBytes = 0 + } else { + if consumeOpts.MaxBytes == unset { + consumeOpts.MaxBytes = 0 + } + if consumeOpts.MaxMessages == unset { + consumeOpts.MaxMessages = DefaultMaxMessages + } + } + + if consumeOpts.ThresholdMessages == 0 { + consumeOpts.ThresholdMessages = int(math.Ceil(float64(consumeOpts.MaxMessages) / 2)) + } + if consumeOpts.ThresholdBytes == 0 { + consumeOpts.ThresholdBytes = int(math.Ceil(float64(consumeOpts.MaxBytes) / 2)) + } + if consumeOpts.Heartbeat == unset { + consumeOpts.Heartbeat = consumeOpts.Expires / 2 + if consumeOpts.Heartbeat > 30*time.Second { + consumeOpts.Heartbeat = 30 * time.Second + } + } + if consumeOpts.Heartbeat > consumeOpts.Expires/2 { + return fmt.Errorf("the value of Heartbeat must be less than 50%% of expiry") + } + return nil +} diff --git a/jsv2/jetstream/pull_test.go b/jetstream/pull_test.go similarity index 81% rename from jsv2/jetstream/pull_test.go rename to jetstream/pull_test.go index 6f7f630..e2784f9 100644 --- a/jsv2/jetstream/pull_test.go +++ b/jetstream/pull_test.go @@ -1,3 +1,16 @@ +// Copyright 2020-2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package jetstream import ( @@ -42,7 +55,7 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -53,21 +66,16 @@ func TestPullConsumerFetch(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - received := make([]Msg, 0) var i int for msg := range msgs.Messages() { - if msg == nil { - if len(testMsgs) != len(received) { - t.Fatalf("Invalid number of messages received; want: %d; got: %d", len(testMsgs), len(received)) - } - return - } if string(msg.Data()) != testMsgs[i] { t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) } - received = append(received, msg) i++ } + if len(testMsgs) != i { + t.Fatalf("Invalid number of messages received; want: %d; got: %d", len(testMsgs), i) + } if msgs.Error() != nil { t.Fatalf("Unexpected error during fetch: %v", msgs.Error()) } @@ -93,7 +101,7 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -161,7 +169,7 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -198,7 +206,7 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -225,46 +233,6 @@ func TestPullConsumerFetch(t *testing.T) { } }) - t.Run("with active streaming", func(t *testing.T) { - srv := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, srv) - nc, err := nats.Connect(srv.ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - js, err := New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - - s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - _, err = c.Consume(func(_ Msg) {}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - _, err = c.Fetch(5) - if err == nil || !errors.Is(err, ErrConsumerHasActiveSubscription) { - t.Fatalf("Expected error: %v; got: %v", ErrConsumerHasActiveSubscription, err) - } - - _, err = c.FetchNoWait(5) - if err == nil || !errors.Is(err, ErrConsumerHasActiveSubscription) { - t.Fatalf("Expected error: %v; got: %v", ErrConsumerHasActiveSubscription, err) - } - }) - t.Run("with timeout", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) @@ -285,12 +253,12 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - msgs, err := c.Fetch(5, WithFetchTimeout(50*time.Millisecond)) + msgs, err := c.Fetch(5, FetchMaxWait(50*time.Millisecond)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -321,19 +289,210 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = c.Fetch(5, WithFetchTimeout(-50*time.Millisecond)) + _, err = c.Fetch(5, FetchMaxWait(-50*time.Millisecond)) if !errors.Is(err, ErrInvalidOption) { t.Fatalf("Expected error: %v; got: %v", ErrInvalidOption, err) } }) } -func TestPullConsumerNext_WithCluster(t *testing.T) { +func TestPullConsumerFetchBytes(t *testing.T) { + testSubject := "FOO.123" + msg := [10]byte{} + publishTestMsgs := func(t *testing.T, nc *nats.Conn, count int) { + for i := 0; i < count; i++ { + if err := nc.Publish(testSubject, msg[:]); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + } + + t.Run("no options, exact byte count received", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy, Name: "con"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc, 5) + // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) + msgs, err := c.FetchBytes(300) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var i int + for msg := range msgs.Messages() { + msg.Ack() + i++ + } + if i != 5 { + t.Fatalf("Expected 5 messages; got: %d", i) + } + if msgs.Error() != nil { + t.Fatalf("Unexpected error during fetch: %v", msgs.Error()) + } + }) + + t.Run("no options, last msg does not fit max bytes", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy, Name: "con"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc, 5) + // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) + msgs, err := c.FetchBytes(250) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var i int + for msg := range msgs.Messages() { + msg.Ack() + i++ + } + if i != 4 { + t.Fatalf("Expected 5 messages; got: %d", i) + } + if msgs.Error() != nil { + t.Fatalf("Unexpected error during fetch: %v", msgs.Error()) + } + }) + t.Run("no options, single msg is too large", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy, Name: "con"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc, 5) + // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) + msgs, err := c.FetchBytes(30) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var i int + for msg := range msgs.Messages() { + msg.Ack() + i++ + } + if i != 0 { + t.Fatalf("Expected 5 messages; got: %d", i) + } + if msgs.Error() != nil { + t.Fatalf("Unexpected error during fetch: %v", msgs.Error()) + } + }) + + t.Run("timeout waiting for messages", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy, Name: "con"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + publishTestMsgs(t, nc, 5) + // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) + msgs, err := c.FetchBytes(1000, FetchMaxWait(50*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var i int + for msg := range msgs.Messages() { + msg.Ack() + i++ + } + if i != 5 { + t.Fatalf("Expected 5 messages; got: %d", i) + } + if msgs.Error() != nil { + t.Fatalf("Unexpected error during fetch: %v", msgs.Error()) + } + }) +} + +func TestPullConsumerFetch_WithCluster(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} publishTestMsgs := func(t *testing.T, nc *nats.Conn) { @@ -371,7 +530,7 @@ func TestPullConsumerNext_WithCluster(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -382,13 +541,11 @@ func TestPullConsumerNext_WithCluster(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - received := make([]Msg, 0) var i int for msg := range msgs.Messages() { if string(msg.Data()) != testMsgs[i] { t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) } - received = append(received, msg) i++ } if msgs.Error() != nil { @@ -416,7 +573,7 @@ func TestPullConsumerNext_WithCluster(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -466,7 +623,7 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -528,18 +685,13 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - // subscribe to next request subject to verify how many next requests were sent - sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name)) - if err != nil { - t.Fatalf("Error on subscribe: %v", err) - } msgs := make([]Msg, 0) - it, err := c.Messages(WithMessagesBatchSize(4)) + it, err := c.Messages(PullMaxMessages(3)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -559,15 +711,6 @@ func TestPullConsumerMessages(t *testing.T) { } it.Stop() time.Sleep(10 * time.Millisecond) - requestsNum, _, err := sub.Pending() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - // with batch size set to 4, and 5 messages published on subject, there should be a total of 5 requests sent - if requestsNum < 5 { - t.Fatalf("Unexpected number of requests sent; want at least 5; got %d", requestsNum) - } - if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) } @@ -598,7 +741,7 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -609,7 +752,7 @@ func TestPullConsumerMessages(t *testing.T) { } msgs := make([]Msg, 0) - it, err := c.Messages(WithMessagesMaxBytes(240)) + it, err := c.Messages(PullMaxBytes(60)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -668,7 +811,7 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -679,7 +822,7 @@ func TestPullConsumerMessages(t *testing.T) { } msgs := make([]Msg, 0) - it, err := c.Messages(WithMessagesMaxBytes(500)) + it, err := c.Messages(PullMaxBytes(150)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -737,7 +880,7 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -748,7 +891,7 @@ func TestPullConsumerMessages(t *testing.T) { } msgs := make([]Msg, 0) - it, err := c.Messages(WithMessagesBatchSize(1)) + it, err := c.Messages(PullMaxMessages(1)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -787,42 +930,6 @@ func TestPullConsumerMessages(t *testing.T) { } }) - t.Run("attempt iteration with active subscription twice on the same consumer", func(t *testing.T) { - srv := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, srv) - nc, err := nats.Connect(srv.ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - js, err := New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - _, err = c.Consume(func(msg Msg) {}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - _, err = c.Messages() - if err == nil || !errors.Is(err, ErrConsumerHasActiveSubscription) { - t.Fatalf("Expected error: %v; got: %v", ErrConsumerHasActiveSubscription, err) - } - }) - t.Run("create iterator, stop, then create again", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) @@ -843,7 +950,7 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -920,12 +1027,12 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = c.Messages(WithMessagesBatchSize(-1)) + _, err = c.Messages(PullMaxMessages(-1)) if err == nil || !errors.Is(err, ErrInvalidOption) { t.Fatalf("Expected error: %v; got: %v", ErrInvalidOption, err) } @@ -951,13 +1058,17 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } msgs := make([]Msg, 0) - it, err := c.Messages(WithMessagesHeartbeat(10 * time.Millisecond)) + // use custom function to bypass validation in test + it, err := c.Messages(pullOptFunc(func(o *consumeOpts) error { + o.Heartbeat = 10 * time.Millisecond + return nil + })) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1006,7 +1117,7 @@ func TestPullConsumerMessages(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1043,7 +1154,7 @@ func TestPullConsumerMessages(t *testing.T) { if len(msgs) != 2*len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) } - case <-errs: + case err := <-errs: t.Fatalf("Unexpected error: %s", err) } }) @@ -1080,7 +1191,7 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1129,20 +1240,41 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - l, err := c.Consume(func(msg Msg) {}) + wg := sync.WaitGroup{} + msgs1, msgs2 := make([]Msg, 0), make([]Msg, 0) + l1, err := c.Consume(func(msg Msg) { + msgs1 = append(msgs1, msg) + wg.Done() + msg.Ack() + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } - defer l.Stop() + defer l1.Stop() + l2, err := c.Consume(func(msg Msg) { + msgs2 = append(msgs2, msg) + wg.Done() + msg.Ack() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer l2.Stop() - _, err = c.Consume(func(msg Msg) {}) - if err == nil || !errors.Is(err, ErrConsumerHasActiveSubscription) { - t.Fatalf("Expected error: %v; got: %v", ErrConsumerHasActiveSubscription, err) + wg.Add(len(testMsgs)) + publishTestMsgs(t, nc) + wg.Wait() + + if len(msgs1)+len(msgs2) != len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs1)+len(msgs2)) + } + if len(msgs1) == 0 || len(msgs2) == 0 { + t.Fatalf("Received no messages on one of the subscriptions") } }) @@ -1166,7 +1298,7 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1235,15 +1367,10 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - // subscribe to next request subject to verify how many next requests were sent - sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name)) - if err != nil { - t.Fatalf("Error on subscribe: %v", err) - } msgs := make([]Msg, 0) wg := &sync.WaitGroup{} @@ -1251,7 +1378,7 @@ func TestPullConsumerConsume(t *testing.T) { l, err := c.Consume(func(msg Msg) { msgs = append(msgs, msg) wg.Done() - }, WithConsumeMaxMessages(4)) + }, PullMaxMessages(4)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1259,15 +1386,56 @@ func TestPullConsumerConsume(t *testing.T) { publishTestMsgs(t, nc) wg.Wait() - requestsNum, _, err := sub.Pending() + + if len(msgs) != len(testMsgs) { + t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) + } + for i, msg := range msgs { + if string(msg.Data()) != testMsgs[i] { + t.Fatalf("Invalid msg on index %d; expected: %s; got: %s", i, testMsgs[i], string(msg.Data())) + } + } + }) + + t.Run("fetch messages one by one", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) if err != nil { t.Fatalf("Unexpected error: %v", err) } - // with batch size set to 2, and 5 messages published on subject, there should be a total of 5 requests sent - if requestsNum != 5 { - t.Fatalf("Unexpected number of requests sent; want 3; got %d", requestsNum) + js, err := New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + msgs := make([]Msg, 0) + wg := &sync.WaitGroup{} + wg.Add(len(testMsgs)) + l, err := c.Consume(func(msg Msg) { + msgs = append(msgs, msg) + wg.Done() + }, PullMaxMessages(1)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer l.Stop() + + publishTestMsgs(t, nc) + wg.Wait() if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) @@ -1299,7 +1467,7 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1316,7 +1484,7 @@ func TestPullConsumerConsume(t *testing.T) { l, err := c.Consume(func(msg Msg) { msgs = append(msgs, msg) wg.Done() - }, WithConsumeMaxBytes(280)) + }, PullMaxBytes(150)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1329,7 +1497,7 @@ func TestPullConsumerConsume(t *testing.T) { } // new request should be sent after each consumed message (msg size is 57) - if requestsNum < 5 { + if requestsNum < 3 { t.Fatalf("Unexpected number of requests sent; want at least 5; got %d", requestsNum) } @@ -1363,13 +1531,13 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = c.Consume(func(_ Msg) { - }, WithConsumeMaxMessages(-1)) + }, PullMaxMessages(-1)) if err == nil || !errors.Is(err, ErrInvalidOption) { t.Fatalf("Expected error: %v; got: %v", ErrInvalidOption, err) } @@ -1395,42 +1563,26 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } - // subscribe to next request subject to verify how many next requests were sent - sub, err := nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.foo.%s", c.CachedInfo().Name)) - if err != nil { - t.Fatalf("Error on subscribe: %v", err) - } - msgs := make([]Msg, 0) wg := &sync.WaitGroup{} wg.Add(len(testMsgs)) l, err := c.Consume(func(msg Msg) { msgs = append(msgs, msg) wg.Done() - }, WithConsumeExpiry(50*time.Millisecond), WithConsumeHeartbeat(20*time.Millisecond)) + }, PullExpiry(2*time.Second)) if err != nil { t.Fatalf("Unexpected error: %v", err) } defer l.Stop() - time.Sleep(60 * time.Millisecond) publishTestMsgs(t, nc) wg.Wait() - requestsNum, _, err := sub.Pending() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // with expiry set to 50ms, and 60ms wait before messages are published, there should be a total of 2 requests sent to the server - if requestsNum < 2 { - t.Fatalf("Unexpected number of requests sent; want at least 2; got %d", requestsNum) - } if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) } @@ -1461,7 +1613,7 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1472,7 +1624,9 @@ func TestPullConsumerConsume(t *testing.T) { l, err := c.Consume(func(msg Msg) { msgs = append(msgs, msg) wg.Done() - }, WithConsumeExpiry(50*time.Millisecond), WithConsumeHeartbeat(20*time.Millisecond), WithConsumeErrHandler(func(consumeCtx ConsumeContext, err error) { + }, pullOptFunc(func(o *consumeOpts) error { + o.Expires = 50 * time.Millisecond + return nil })) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1516,13 +1670,13 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } _, err = c.Consume(func(_ Msg) { - }, WithConsumeExpiry(-1)) + }, PullExpiry(-1)) if err == nil || !errors.Is(err, ErrInvalidOption) { t.Fatalf("Expected error: %v; got: %v", ErrInvalidOption, err) } @@ -1548,7 +1702,7 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1559,7 +1713,7 @@ func TestPullConsumerConsume(t *testing.T) { l, err := c.Consume(func(msg Msg) { msgs = append(msgs, msg) wg.Done() - }, WithConsumeHeartbeat(10*time.Millisecond)) + }, PullMaxBytes(1*time.Second)) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1596,7 +1750,7 @@ func TestPullConsumerConsume(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1623,7 +1777,7 @@ func TestPullConsumerConsume(t *testing.T) { }) } -func TestPullConsumerStream_WithCluster(t *testing.T) { +func TestPullConsumerConsume_WithCluster(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} publishTestMsgs := func(t *testing.T, nc *nats.Conn) { @@ -1660,7 +1814,7 @@ func TestPullConsumerStream_WithCluster(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1708,7 +1862,7 @@ func TestPullConsumerStream_WithCluster(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/jsv2/jetstream/stream.go b/jetstream/stream.go similarity index 91% rename from jsv2/jetstream/stream.go rename to jetstream/stream.go index eb47604..2f93acc 100644 --- a/jsv2/jetstream/stream.go +++ b/jetstream/stream.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -21,6 +21,7 @@ import ( "time" "github.com/nats-io/nats.go" + "github.com/nats-io/nuid" ) type ( @@ -49,18 +50,26 @@ type ( } streamConsumerManager interface { - // CreateConsumer creates a consumer on a given stream with given config - // This operation is idempotent - if a consumer already exists, it will be a no-op (or error if configs do not match) - // Consumer interface is returned, serving as a hook to operate on a consumer (e.g. fetch messages) - CreateConsumer(context.Context, ConsumerConfig) (Consumer, error) - // UpdateConsumer updates an existing consumer - UpdateConsumer(context.Context, ConsumerConfig) (Consumer, error) + // AddConsumer creates a consumer on a given stream with given config. + // If consumer already exists, it will be updated (if possible). + // Consumer interface is returned, serving as a hook to operate on a consumer (e.g. fetch messages). + AddConsumer(context.Context, ConsumerConfig) (Consumer, error) + + // OrderedConsumer returns an OrderedConsumer instance. + // OrderedConsumer allows fetching messages from a stream (just like standard consumer), + // for in order delivery of messages. Underlying consumer is re-created when necessary, + // without additional client code. + OrderedConsumer(context.Context, OrderedConsumerConfig) (Consumer, error) + // Consumer returns a Consumer interface for an existing consumer Consumer(context.Context, string) (Consumer, error) + // DeleteConsumer removes a consumer DeleteConsumer(context.Context, string) error + // ListConsumers returns ConsumerInfoLister enabling iterating over a channel of consumer infos ListConsumers(context.Context) ConsumerInfoLister + // ConsumerNames returns a ConsumerNameLister enabling iterating over a channel of consumer names ConsumerNames(context.Context) ConsumerNameLister } @@ -180,31 +189,23 @@ type ( } ) -func (s *stream) CreateConsumer(ctx context.Context, cfg ConsumerConfig) (Consumer, error) { - if cfg.Durable != "" { - c, err := s.Consumer(ctx, cfg.Durable) - if err != nil && !errors.Is(err, ErrConsumerNotFound) { - return nil, err - } - if c != nil { - if err := compareConsumerConfig(&c.CachedInfo().Config, &cfg); err != nil { - return nil, fmt.Errorf("%w: %s", ErrConsumerNameAlreadyInUse, cfg.Durable) - } - return c, nil - } - } +func (s *stream) AddConsumer(ctx context.Context, cfg ConsumerConfig) (Consumer, error) { return upsertConsumer(ctx, s.jetStream, s.name, cfg) } -func (s *stream) UpdateConsumer(ctx context.Context, cfg ConsumerConfig) (Consumer, error) { - if cfg.Durable == "" { - return nil, ErrConsumerNameRequired +func (s *stream) OrderedConsumer(ctx context.Context, cfg OrderedConsumerConfig) (Consumer, error) { + oc := &orderedConsumer{ + jetStream: s.jetStream, + cfg: &cfg, + stream: s.name, + namePrefix: nuid.Next(), + doReset: make(chan struct{}, 1), } - _, err := s.Consumer(ctx, cfg.Durable) - if err != nil { - return nil, err + if cfg.OptStartSeq != 0 { + oc.cursor.streamSeq = cfg.OptStartSeq - 1 } - return upsertConsumer(ctx, s.jetStream, s.name, cfg) + + return oc, nil } func (s *stream) Consumer(ctx context.Context, name string) (Consumer, error) { diff --git a/jsv2/jetstream/stream_config.go b/jetstream/stream_config.go similarity index 99% rename from jsv2/jetstream/stream_config.go rename to jetstream/stream_config.go index 18ce624..69dc9df 100644 --- a/jsv2/jetstream/stream_config.go +++ b/jetstream/stream_config.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/jsv2/jetstream/stream_test.go b/jetstream/stream_test.go similarity index 84% rename from jsv2/jetstream/stream_test.go rename to jetstream/stream_test.go index 06f95de..c7120d0 100644 --- a/jsv2/jetstream/stream_test.go +++ b/jetstream/stream_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2022 The NATS Authors +// Copyright 2020-2023 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -25,7 +25,7 @@ import ( "github.com/nats-io/nats.go" ) -func TestCreateConsumer(t *testing.T) { +func TestAddConsumer(t *testing.T) { tests := []struct { name string consumerConfig ConsumerConfig @@ -43,13 +43,13 @@ func TestCreateConsumer(t *testing.T) { shouldCreate: true, }, { - name: "consumer already exists, idempotent operation", - consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckExplicitPolicy}, + name: "consumer already exists, update", + consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckExplicitPolicy, Description: "test consumer"}, }, { - name: "consumer already exists, config mismatch", - consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckExplicitPolicy, Description: "test"}, - withError: ErrConsumerNameAlreadyInUse, + name: "consumer already exists, illegal update", + consumerConfig: ConsumerConfig{Durable: "dur", AckPolicy: AckNonePolicy}, + withError: ErrConsumerCreate, }, { name: "invalid durable name", @@ -81,12 +81,12 @@ func TestCreateConsumer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var sub *nats.Subscription - if test.consumerConfig.Durable != "" { - sub, err = nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.DURABLE.CREATE.foo.%s", test.consumerConfig.Durable)) + if test.consumerConfig.FilterSubject != "" { + sub, err = nc.SubscribeSync(fmt.Sprintf("$JS.API.CONSUMER.CREATE.foo.*.%s", test.consumerConfig.FilterSubject)) } else { - sub, err = nc.SubscribeSync("$JS.API.CONSUMER.CREATE.foo") + sub, err = nc.SubscribeSync("$JS.API.CONSUMER.CREATE.foo.*") } - c, err := s.CreateConsumer(ctx, test.consumerConfig) + c, err := s.AddConsumer(ctx, test.consumerConfig) if test.withError != nil { if err == nil || !errors.Is(err, test.withError) { t.Fatalf("Expected error: %v; got: %v", test.withError, err) @@ -109,114 +109,6 @@ func TestCreateConsumer(t *testing.T) { } } -func TestCreateConsumer_WithCluster(t *testing.T) { - name := "cluster" - stream := StreamConfig{ - Name: name, - Replicas: 1, - Subjects: []string{"FOO.*"}, - } - t.Run("consumer name conflict", func(t *testing.T) { - withJSClusterAndStream(t, name, 3, stream, func(t *testing.T, subject string, srvs ...*jsServer) { - nc, err := nats.Connect(srvs[0].ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - js, err := New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.Stream(ctx, stream.Name) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "test"}) - if err == nil || !errors.Is(err, ErrConsumerNameAlreadyInUse) { - t.Fatalf("Expected error: %v; got %v", ErrConsumerNameAlreadyInUse, err) - } - }) - }) -} - -func TestUpdateConsumer(t *testing.T) { - tests := []struct { - name string - durable string - withError error - }{ - { - name: "update consumer", - durable: "dur", - }, - { - name: "consumer does not exist", - durable: "abc", - withError: ErrConsumerNotFound, - }, - { - name: "invalid durable name", - durable: "dur.123", - withError: ErrInvalidConsumerName, - }, - } - - srv := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, srv) - nc, err := nats.Connect(srv.ClientURL()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - js, err := New(nc) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer nc.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - s, err := js.CreateStream(ctx, StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c, err := s.UpdateConsumer(ctx, ConsumerConfig{Durable: test.durable, AckPolicy: AckAllPolicy, Description: test.name}) - if test.withError != nil { - if err == nil || !errors.Is(err, test.withError) { - t.Fatalf("Expected error: %v; got: %v", test.withError, err) - } - return - } - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - c, err = s.Consumer(ctx, c.CachedInfo().Name) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if c.CachedInfo().Config.Description != test.name { - t.Fatalf("Invalid consumer description after update; want: %s; got: %s", test.name, c.CachedInfo().Config.Description) - } - }) - } -} - func TestConsumer(t *testing.T) { tests := []struct { name string @@ -258,7 +150,7 @@ func TestConsumer(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) + _, err = s.AddConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -323,7 +215,7 @@ func TestDeleteConsumer(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - _, err = s.CreateConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) + _, err = s.AddConsumer(ctx, ConsumerConfig{Durable: "dur", AckPolicy: AckAllPolicy, Description: "desc"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -794,7 +686,7 @@ func TestListConsumers(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } for i := 0; i < test.consumersNum; i++ { - _, err = s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + _, err = s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -857,7 +749,7 @@ func TestConsumerNames(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } for i := 0; i < test.consumersNum; i++ { - _, err = s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + _, err = s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -963,7 +855,7 @@ func TestPurgeStream(t *testing.T) { } return } - c, err := s.CreateConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) + c, err := s.AddConsumer(ctx, ConsumerConfig{AckPolicy: AckExplicitPolicy}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -977,8 +869,7 @@ func TestPurgeStream(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - var msg Msg - msg = <-msgs.Messages() + msg := <-msgs.Messages() if msg == nil { break Loop } diff --git a/jsv2/jetstream/consumer.go b/jsv2/jetstream/consumer.go deleted file mode 100644 index e1489d1..0000000 --- a/jsv2/jetstream/consumer.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2020-2022 The NATS Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jetstream - -import ( - "context" - "encoding/json" - "fmt" - "strings" -) - -type ( - - // Consumer contains methods for fetching/processing messages from a stream, as well as fetching consumer info - Consumer interface { - // Fetch is used to retrieve up to a provided number of messages from a stream. - // This method will always send a single request and wait until either all messages are retreived - // or context reaches its deadline. - Fetch(int, ...FetchOpt) (MessageBatch, error) - // FetchNoWait is used to retrieve up to a provided number of messages from a stream. - // This method will always send a single request and immediately return up to a provided number of messages - FetchNoWait(batch int) (MessageBatch, error) - // Consume can be used to continuously receive messages and handle them with the provided callback function - Consume(MessageHandler, ...ConsumeOpts) (ConsumeContext, error) - // Messages returns [MessagesContext], allowing continously iterating over messages on a stream. - Messages(...ConsumerMessagesOpts) (MessagesContext, error) - - // Info returns Consumer details - Info(context.Context) (*ConsumerInfo, error) - // CachedInfo returns [*ConsumerInfo] cached on a consumer struct - CachedInfo() *ConsumerInfo - } -) - -// Info returns [ConsumerInfo] for a given consumer -func (p *pullConsumer) Info(ctx context.Context) (*ConsumerInfo, error) { - infoSubject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiConsumerInfoT, p.stream, p.name)) - var resp consumerInfoResponse - - if _, err := p.jetStream.apiRequestJSON(ctx, infoSubject, &resp); err != nil { - return nil, err - } - if resp.Error != nil { - if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { - return nil, ErrConsumerNotFound - } - return nil, resp.Error - } - - p.info = resp.ConsumerInfo - return resp.ConsumerInfo, nil -} - -// CachedInfo returns [ConsumerInfo] fetched when initializing/updating a consumer -// -// NOTE: The returned object might not be up to date with the most recent updates on the server -// For up-to-date information, use [Info] -func (p *pullConsumer) CachedInfo() *ConsumerInfo { - return p.info -} - -func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg ConsumerConfig) (Consumer, error) { - req := createConsumerRequest{ - Stream: stream, - Config: &cfg, - } - reqJSON, err := json.Marshal(req) - if err != nil { - return nil, err - } - - var ccSubj string - if cfg.Durable != "" { - if err := validateDurableName(cfg.Durable); err != nil { - return nil, err - } - ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiDurableCreateT, stream, cfg.Durable)) - } else { - ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateT, stream)) - } - var resp consumerInfoResponse - - if _, err := js.apiRequestJSON(ctx, ccSubj, &resp, reqJSON); err != nil { - return nil, err - } - if resp.Error != nil { - if resp.Error.ErrorCode == JSErrCodeStreamNotFound { - return nil, ErrStreamNotFound - } - return nil, resp.Error - } - - return &pullConsumer{ - jetStream: js, - stream: stream, - name: resp.Name, - durable: cfg.Durable != "", - info: resp.ConsumerInfo, - }, nil -} - -func getConsumer(ctx context.Context, js *jetStream, stream, name string) (Consumer, error) { - if err := validateDurableName(name); err != nil { - return nil, err - } - infoSubject := apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerInfoT, stream, name)) - - var resp consumerInfoResponse - - if _, err := js.apiRequestJSON(ctx, infoSubject, &resp); err != nil { - return nil, err - } - if resp.Error != nil { - if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { - return nil, ErrConsumerNotFound - } - return nil, resp.Error - } - - return &pullConsumer{ - jetStream: js, - stream: stream, - name: name, - durable: resp.Config.Durable != "", - info: resp.ConsumerInfo, - }, nil -} - -func deleteConsumer(ctx context.Context, js *jetStream, stream, consumer string) error { - if err := validateDurableName(consumer); err != nil { - return err - } - deleteSubject := apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerDeleteT, stream, consumer)) - - var resp consumerDeleteResponse - - if _, err := js.apiRequestJSON(ctx, deleteSubject, &resp); err != nil { - return err - } - if resp.Error != nil { - if resp.Error.ErrorCode == JSErrCodeConsumerNotFound { - return ErrConsumerNotFound - } - return resp.Error - } - return nil -} - -func validateDurableName(dur string) error { - if strings.Contains(dur, ".") { - return fmt.Errorf("%w: '%s'", ErrInvalidConsumerName, dur) - } - return nil -} - -func compareConsumerConfig(s, u *ConsumerConfig) error { - makeErr := func(fieldName string, usrVal, srvVal interface{}) error { - return fmt.Errorf("configuration requests %s to be %v, but consumer's value is %v", fieldName, usrVal, srvVal) - } - - if u.Durable != s.Durable { - return makeErr("durable", u.Durable, s.Durable) - } - if u.Description != s.Description { - return makeErr("description", u.Description, s.Description) - } - if u.DeliverPolicy != s.DeliverPolicy { - return makeErr("deliver policy", u.DeliverPolicy, s.DeliverPolicy) - } - if u.OptStartSeq != s.OptStartSeq { - return makeErr("optional start sequence", u.OptStartSeq, s.OptStartSeq) - } - if u.OptStartTime != nil && !u.OptStartTime.IsZero() && !(*u.OptStartTime).Equal(*s.OptStartTime) { - return makeErr("optional start time", u.OptStartTime, s.OptStartTime) - } - if u.AckPolicy != s.AckPolicy { - return makeErr("ack policy", u.AckPolicy, s.AckPolicy) - } - if u.AckWait != 0 && u.AckWait != s.AckWait { - return makeErr("ack wait", u.AckWait.String(), s.AckWait.String()) - } - if !(u.MaxDeliver == 0 && s.MaxDeliver == -1) && u.MaxDeliver != s.MaxDeliver { - return makeErr("max deliver", u.MaxDeliver, s.MaxDeliver) - } - if len(u.BackOff) != len(s.BackOff) { - return makeErr("backoff", u.BackOff, s.BackOff) - } - for i, val := range u.BackOff { - if val != s.BackOff[i] { - return makeErr("backoff", u.BackOff, s.BackOff) - } - } - if u.FilterSubject != s.FilterSubject { - return makeErr("filter subject", u.FilterSubject, s.FilterSubject) - } - if u.ReplayPolicy != s.ReplayPolicy { - return makeErr("replay policy", u.ReplayPolicy, s.ReplayPolicy) - } - if u.RateLimit != s.RateLimit { - return makeErr("rate limit", u.RateLimit, s.RateLimit) - } - if u.SampleFrequency != s.SampleFrequency { - return makeErr("sample frequency", u.SampleFrequency, s.SampleFrequency) - } - if u.MaxWaiting != 0 && u.MaxWaiting != s.MaxWaiting { - return makeErr("max waiting", u.MaxWaiting, s.MaxWaiting) - } - if u.MaxAckPending != 0 && u.MaxAckPending != s.MaxAckPending { - return makeErr("max ack pending", u.MaxAckPending, s.MaxAckPending) - } - if u.FlowControl != s.FlowControl { - return makeErr("flow control", u.FlowControl, s.FlowControl) - } - if u.Heartbeat != s.Heartbeat { - return makeErr("heartbeat", u.Heartbeat, s.Heartbeat) - } - if u.HeadersOnly != s.HeadersOnly { - return makeErr("headers only", u.HeadersOnly, s.HeadersOnly) - } - if u.MaxRequestBatch != s.MaxRequestBatch { - return makeErr("max request batch", u.MaxRequestBatch, s.MaxRequestBatch) - } - if u.MaxRequestExpires != s.MaxRequestExpires { - return makeErr("max request expires", u.MaxRequestExpires.String(), s.MaxRequestExpires.String()) - } - if u.DeliverSubject != s.DeliverSubject { - return makeErr("deliver subject", u.DeliverSubject, s.DeliverSubject) - } - if u.DeliverGroup != s.DeliverGroup { - return makeErr("deliver group", u.DeliverSubject, s.DeliverSubject) - } - if u.InactiveThreshold != s.InactiveThreshold { - return makeErr("inactive threshhold", u.InactiveThreshold.String(), s.InactiveThreshold.String()) - } - if u.Replicas != s.Replicas { - return makeErr("replicas", u.Replicas, s.Replicas) - } - if u.MemoryStorage != s.MemoryStorage { - return makeErr("memory storage", u.MemoryStorage, s.MemoryStorage) - } - return nil -} diff --git a/jsv2/jetstream/pull.go b/jsv2/jetstream/pull.go deleted file mode 100644 index 9745561..0000000 --- a/jsv2/jetstream/pull.go +++ /dev/null @@ -1,722 +0,0 @@ -// Copyright 2020-2022 The NATS Authors -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package jetstream - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "math" - "sync" - "sync/atomic" - "time" - - "github.com/nats-io/nats.go" -) - -type ( - // MessagesContext supports iterating over a messages on a stream. - MessagesContext interface { - // Next retreives nest message on a stream. It will block until the next message is available. - Next() (Msg, error) - // Stop closes the iterator and cancels subscription. - Stop() - } - - ConsumeContext interface { - Stop() - } - - // MessageHandler is a handler function used as callback in [Consume] - MessageHandler func(msg Msg) - - // ConsumeOpts represent additional options used in [Consume] for pull consumers - ConsumeOpts func(*consumeOpts) error - - // ConsumerMessagesOpts represent additional options used in [Messages] for pull consumers - ConsumerMessagesOpts func(*consumeOpts) error - - pullConsumer struct { - sync.Mutex - jetStream *jetStream - stream string - durable bool - name string - info *ConsumerInfo - isSubscribed uint32 - } - - pullRequest struct { - Expires time.Duration `json:"expires,omitempty"` - Batch int `json:"batch,omitempty"` - MaxBytes int `json:"max_bytes,omitempty"` - NoWait bool `json:"no_wait,omitempty"` - Heartbeat time.Duration `json:"idle_heartbeat,omitempty"` - } - - consumeOpts struct { - Expires time.Duration - MaxMessages int - MaxBytes int - Heartbeat time.Duration - ErrHandler ConsumeErrHandler - } - - ConsumeErrHandler func(consumeCtx ConsumeContext, err error) - - pullSubscription struct { - sync.Mutex - consumer *pullConsumer - subscription *nats.Subscription - req *pullRequest - msgs chan *nats.Msg - errs chan error - pending pendingMsgs - hbMonitor *hbMonitor - fetchInProgress uint32 - closed uint32 - done chan struct{} - reconnected chan struct{} - disconnected chan struct{} - fetchNext chan struct{} - reconnectHandler nats.ConnHandler - disconnectHandler nats.ConnErrHandler - consumeOpts *consumeOpts - } - - pendingMsgs struct { - msgCount int - byteCount int - } - - MessageBatch interface { - Messages() <-chan Msg - Error() error - } - - fetchResult struct { - msgs chan Msg - err error - } - - FetchOpt func(*pullRequest) error - - hbMonitor struct { - timer *time.Timer - sync.Mutex - } -) - -const ( - DefaultBatchSize = 100 - DefaultExpires = 30 * time.Second - DefaultHeartbeat = 15 * time.Second - DefaultThreshold = 0.75 -) - -// Messages returns MessagesContext, allowing continuously iterating over messages on a stream. -// -// Available options: -// [WithMessagesMaxMessages] - sets maximum number of messages stored in a buffer, default is set to 100 -// [WithMessagesMaxBytes] - sets maximum number of bytes stored in a buffer -// [WithMessagesHeartbeat] - sets an idle heartbeat setting for a pull request, default value is 5 seconds. -func (p *pullConsumer) Messages(opts ...ConsumerMessagesOpts) (MessagesContext, error) { - if atomic.LoadUint32(&p.isSubscribed) == 1 { - return nil, ErrConsumerHasActiveSubscription - } - atomic.StoreUint32(&p.isSubscribed, 1) - // threshold := DefaultThreshold - consumeOpts := &consumeOpts{ - MaxMessages: DefaultBatchSize, - Expires: DefaultExpires, - Heartbeat: DefaultHeartbeat, - } - for _, opt := range opts { - if err := opt(consumeOpts); err != nil { - return nil, err - } - } - req := &pullRequest{ - Expires: consumeOpts.Expires, - Batch: int(math.Ceil(float64(consumeOpts.MaxMessages) / 4)), - MaxBytes: consumeOpts.MaxBytes / 4, - Heartbeat: consumeOpts.Heartbeat, - } - subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) - - msgs := make(chan *nats.Msg, consumeOpts.MaxMessages) - - sub := &pullSubscription{ - consumer: p, - req: req, - done: make(chan struct{}, 1), - msgs: msgs, - errs: make(chan error, 1), - fetchNext: make(chan struct{}, 1), - reconnected: make(chan struct{}), - disconnected: make(chan struct{}), - reconnectHandler: p.jetStream.conn.ReconnectHandler(), - disconnectHandler: p.jetStream.conn.DisconnectErrHandler(), - consumeOpts: consumeOpts, - } - p.jetStream.conn.SetReconnectHandler(func(c *nats.Conn) { - if sub.reconnectHandler != nil { - sub.reconnectHandler(p.jetStream.conn) - } - sub.reconnected <- struct{}{} - }) - p.jetStream.conn.SetDisconnectErrHandler(func(c *nats.Conn, err error) { - if sub.disconnectHandler != nil { - sub.disconnectHandler(p.jetStream.conn, err) - } - sub.disconnected <- struct{}{} - }) - inbox := nats.NewInbox() - var err error - sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs) - if err != nil { - return nil, err - } - - sub.hbMonitor = sub.scheduleHeartbeatCheck(req.Heartbeat) - go func() { - <-sub.done - sub.cleanupSubscriptionAndRestoreConnHandler() - }() - - // initial pull - if err := sub.pull(*req, subject); err != nil { - sub.errs <- err - } - sub.pending.msgCount = req.Batch - sub.pending.byteCount = req.MaxBytes - go sub.pullMessages(subject) - - return sub, nil -} - -func (s *pullSubscription) Next() (Msg, error) { - s.Lock() - defer s.Unlock() - if atomic.LoadUint32(&s.closed) == 1 { - return nil, ErrMsgIteratorClosed - } - threshold := DefaultThreshold - - for { - if float64(s.pending.msgCount) <= float64(s.consumeOpts.MaxMessages)*threshold || - (float64(s.pending.byteCount) <= float64(s.consumeOpts.MaxBytes)*threshold && s.req.MaxBytes != 0) && - atomic.LoadUint32(&s.fetchInProgress) == 1 { - - s.pending.msgCount += s.req.Batch - if s.req.MaxBytes > 0 { - s.pending.byteCount += s.req.MaxBytes - } - s.fetchNext <- struct{}{} - } - select { - case msg := <-s.msgs: - if s.hbMonitor != nil { - s.hbMonitor.Reset(2 * s.req.Heartbeat) - } - userMsg, err := checkMsg(msg) - if !userMsg { - // heartbeat message - if err == nil { - continue - } - if !errors.Is(err, nats.ErrTimeout) && !errors.Is(err, ErrMaxBytesExceeded) { - if s.consumeOpts.ErrHandler != nil { - s.consumeOpts.ErrHandler(s, err) - } - if errors.Is(err, ErrConsumerDeleted) || errors.Is(err, ErrBadRequest) { - s.Stop() - return nil, err - } - if errors.Is(err, ErrConsumerLeadershipChanged) { - s.pending.msgCount = 0 - s.pending.byteCount = 0 - } - continue - } - msgsLeft, bytesLeft, err := parsePending(msg) - if err != nil { - if s.consumeOpts.ErrHandler != nil { - s.consumeOpts.ErrHandler(s, err) - } - } - s.pending.msgCount -= msgsLeft - if s.pending.msgCount < 0 { - s.pending.msgCount = 0 - } - if s.req.MaxBytes > 0 { - s.pending.byteCount -= bytesLeft - if s.pending.byteCount < 0 { - s.pending.byteCount = 0 - } - } - continue - } - s.pending.msgCount-- - if s.req.MaxBytes > 0 { - s.pending.byteCount -= msgSize(msg) - } - return s.consumer.jetStream.toJSMsg(msg), nil - case <-s.disconnected: - if s.hbMonitor != nil { - s.hbMonitor.Stop() - } - case <-s.reconnected: - // try fetching consumer info several times to make sure consumer is available after reconnect - for i := 0; i < 5; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := s.consumer.Info(ctx) - cancel() - if err == nil { - break - } - if err != nil { - if i == 4 { - s.cleanupSubscriptionAndRestoreConnHandler() - if s.consumeOpts.ErrHandler != nil { - s.consumeOpts.ErrHandler(s, err) - } - return nil, err - } - } - time.Sleep(5 * time.Second) - } - s.pending.msgCount = 0 - s.pending.byteCount = 0 - case err := <-s.errs: - if s.consumeOpts.ErrHandler != nil { - s.consumeOpts.ErrHandler(s, err) - } - if errors.Is(err, ErrNoHeartbeat) { - s.pending.msgCount = 0 - s.pending.byteCount = 0 - } - } - } -} - -func (hb *hbMonitor) Stop() { - hb.Mutex.Lock() - hb.timer.Stop() - hb.Mutex.Unlock() -} - -func (hb *hbMonitor) Reset(dur time.Duration) { - hb.Mutex.Lock() - hb.timer.Reset(dur) - hb.Mutex.Unlock() -} - -func (s *pullSubscription) Stop() { - if atomic.LoadUint32(&s.closed) == 1 { - return - } - close(s.done) - atomic.StoreUint32(&s.consumer.isSubscribed, 0) - atomic.StoreUint32(&s.closed, 1) -} - -// Fetch sends a single request to retrieve given number of messages. -// It will wait up to provided expiry time if not all messages are available. -func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { - p.Lock() - if atomic.LoadUint32(&p.isSubscribed) == 1 { - p.Unlock() - return nil, ErrConsumerHasActiveSubscription - } - req := &pullRequest{ - Batch: batch, - Expires: DefaultExpires, - } - for _, opt := range opts { - if err := opt(req); err != nil { - return nil, err - } - } - // for longer pulls, set heartbeat value - if req.Expires >= 10*time.Second { - req.Heartbeat = 5 * time.Second - } - p.Unlock() - - return p.fetch(req) - -} - -// Fetch sends a single request to retrieve given number of messages. -// If there are any messages available at the time of sending request, -// FetchNoWait will return immediately. -func (p *pullConsumer) FetchNoWait(batch int) (MessageBatch, error) { - p.Lock() - if atomic.LoadUint32(&p.isSubscribed) == 1 { - p.Unlock() - return nil, ErrConsumerHasActiveSubscription - } - req := &pullRequest{ - Batch: batch, - NoWait: true, - } - p.Unlock() - - return p.fetch(req) -} - -func (p *pullConsumer) fetch(req *pullRequest) (MessageBatch, error) { - res := &fetchResult{ - msgs: make(chan Msg, req.Batch), - } - msgs := make(chan *nats.Msg, 2*req.Batch) - subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) - - sub := &pullSubscription{ - consumer: p, - req: req, - done: make(chan struct{}, 1), - msgs: msgs, - errs: make(chan error, 1), - fetchNext: make(chan struct{}, 1), - reconnected: make(chan struct{}), - } - inbox := nats.NewInbox() - var err error - sub.subscription, err = p.jetStream.conn.ChanSubscribe(inbox, sub.msgs) - if err != nil { - return nil, err - } - if err := sub.pull(*req, subject); err != nil { - return nil, err - } - - var received int - hbTimer := sub.scheduleHeartbeatCheck(req.Heartbeat) - go func(res *fetchResult) { - defer sub.subscription.Unsubscribe() - defer close(res.msgs) - for { - if received == req.Batch { - return - } - select { - case msg := <-msgs: - if hbTimer != nil { - hbTimer.Reset(2 * req.Heartbeat) - } - userMsg, err := checkMsg(msg) - if err != nil { - if !errors.Is(err, nats.ErrTimeout) && !errors.Is(err, ErrNoMessages) { - res.err = err - return - } - return - } - if !userMsg { - continue - } - res.msgs <- p.jetStream.toJSMsg(msg) - received++ - case <-time.After(req.Expires + 5*time.Second): - res.err = fmt.Errorf("fetch timed out") - return - } - } - }(res) - return res, nil -} - -func (fr *fetchResult) Messages() <-chan Msg { - return fr.msgs -} - -func (fr *fetchResult) Error() error { - return fr.err -} - -// Consume returns a ConsumeContext, allowing for processing incoming messages from a stream in a given callback function. -// -// Available options: -// [WithConsumeMaxMessages] - sets maximum number of messages stored in a buffer, default is set to 100 -// [WithConsumeMaxBytes] - sets maximum number of bytes stored in a buffer -// [WitConsumeExpiry] - sets a timeout for individual batch request, default is set to 30 seconds -// [WithConsumeHeartbeat] - sets an idle heartbeat setting for a pull request, default is set to 5s -// [WithConsumeErrHandler] - sets custom consume error callback handler -func (p *pullConsumer) Consume(handler MessageHandler, opts ...ConsumeOpts) (ConsumeContext, error) { - if atomic.LoadUint32(&p.isSubscribed) == 1 { - return nil, ErrConsumerHasActiveSubscription - } - if handler == nil { - return nil, ErrHandlerRequired - } - threshold := DefaultThreshold - consumeOpts := &consumeOpts{ - MaxMessages: DefaultBatchSize, - Expires: DefaultExpires, - Heartbeat: DefaultHeartbeat, - } - for _, opt := range opts { - if err := opt(consumeOpts); err != nil { - return nil, err - } - } - req := &pullRequest{ - Expires: consumeOpts.Expires, - Batch: int(math.Ceil(float64(consumeOpts.MaxMessages) / 4)), - MaxBytes: consumeOpts.MaxBytes / 4, - Heartbeat: consumeOpts.Heartbeat, - } - - subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) - - atomic.StoreUint32(&p.isSubscribed, 1) - sub := &pullSubscription{ - consumer: p, - req: req, - errs: make(chan error, 1), - done: make(chan struct{}, 1), - fetchNext: make(chan struct{}, 1), - reconnected: make(chan struct{}), - disconnected: make(chan struct{}), - reconnectHandler: p.jetStream.conn.ReconnectHandler(), - disconnectHandler: p.jetStream.conn.DisconnectErrHandler(), - consumeOpts: consumeOpts, - } - - p.jetStream.conn.SetReconnectHandler(func(c *nats.Conn) { - if sub.reconnectHandler != nil { - sub.reconnectHandler(p.jetStream.conn) - } - sub.reconnected <- struct{}{} - }) - p.jetStream.conn.SetDisconnectErrHandler(func(c *nats.Conn, err error) { - if sub.disconnectHandler != nil { - sub.disconnectHandler(p.jetStream.conn, err) - } - sub.disconnected <- struct{}{} - }) - sub.hbMonitor = sub.scheduleHeartbeatCheck(req.Heartbeat) - - internalHandler := func(msg *nats.Msg) { - if sub.hbMonitor != nil { - sub.hbMonitor.Reset(2 * req.Heartbeat) - } - userMsg, err := checkMsg(msg) - if !userMsg && err == nil { - return - } - defer func() { - if float64(sub.pending.msgCount) <= float64(consumeOpts.MaxMessages)*threshold || - (float64(sub.pending.byteCount) <= float64(consumeOpts.MaxBytes)*threshold && sub.req.MaxBytes != 0) && - atomic.LoadUint32(&sub.fetchInProgress) == 1 { - - sub.pending.msgCount += req.Batch - if sub.req.MaxBytes != 0 { - sub.pending.byteCount += req.MaxBytes - } - sub.fetchNext <- struct{}{} - } - }() - if !userMsg { - // heartbeat message - if err == nil { - return - } - if !errors.Is(err, nats.ErrTimeout) && !errors.Is(err, ErrMaxBytesExceeded) { - if sub.consumeOpts.ErrHandler != nil { - sub.consumeOpts.ErrHandler(sub, err) - } - if errors.Is(err, ErrConsumerDeleted) || errors.Is(err, ErrBadRequest) { - sub.Stop() - } - if errors.Is(err, ErrConsumerLeadershipChanged) { - sub.pending.msgCount = 0 - sub.pending.byteCount = 0 - } - return - } - msgsLeft, bytesLeft, err := parsePending(msg) - if err != nil { - if sub.consumeOpts.ErrHandler != nil { - sub.consumeOpts.ErrHandler(sub, err) - } - } - sub.pending.msgCount -= msgsLeft - if sub.pending.msgCount < 0 { - sub.pending.msgCount = 0 - } - if sub.req.MaxBytes > 0 { - sub.pending.byteCount -= bytesLeft - if sub.pending.byteCount < 0 { - sub.pending.byteCount = 0 - } - } - return - } - handler(p.jetStream.toJSMsg(msg)) - sub.pending.msgCount-- - if sub.req.MaxBytes != 0 { - sub.pending.byteCount -= msgSize(msg) - } - } - inbox := nats.NewInbox() - var err error - sub.subscription, err = p.jetStream.conn.Subscribe(inbox, internalHandler) - if err != nil { - return nil, err - } - - // initial pull - sub.pending.msgCount = sub.req.Batch - sub.pending.byteCount = sub.req.MaxBytes - if err := sub.pull(*req, subject); err != nil { - sub.errs <- err - } - - go func() { - for { - if atomic.LoadUint32(&sub.closed) == 1 { - return - } - select { - case <-sub.disconnected: - if sub.hbMonitor != nil { - sub.hbMonitor.Stop() - } - case <-sub.reconnected: - // try fetching consumer info several times to make sure consumer is available after reconnect - for i := 0; i < 5; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - _, err := p.Info(ctx) - cancel() - if err == nil { - break - } - if err != nil { - if i == 4 { - sub.cleanupSubscriptionAndRestoreConnHandler() - if sub.consumeOpts.ErrHandler != nil { - sub.consumeOpts.ErrHandler(sub, err) - } - return - } - } - time.Sleep(5 * time.Second) - } - - sub.pending.msgCount = req.Batch - sub.pending.byteCount = req.MaxBytes - sub.fetchNext <- struct{}{} - case err := <-sub.errs: - if sub.consumeOpts.ErrHandler != nil { - sub.consumeOpts.ErrHandler(sub, err) - } - if errors.Is(err, ErrNoHeartbeat) { - sub.pending.msgCount = 0 - sub.pending.byteCount = 0 - sub.pending.msgCount += req.Batch - if sub.req.MaxBytes != 0 { - sub.pending.byteCount += req.MaxBytes - } - sub.fetchNext <- struct{}{} - } - } - } - }() - - go sub.pullMessages(subject) - - return sub, nil -} - -func (s *pullSubscription) pullMessages(subject string) { - for { - select { - case <-s.fetchNext: - atomic.StoreUint32(&s.fetchInProgress, 1) - if err := s.pull(*s.req, subject); err != nil { - if errors.Is(err, ErrMsgIteratorClosed) { - s.cleanupSubscriptionAndRestoreConnHandler() - return - } - s.errs <- err - } - atomic.StoreUint32(&s.fetchInProgress, 0) - case <-s.done: - s.cleanupSubscriptionAndRestoreConnHandler() - return - } - } -} - -func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor { - if dur == 0 { - return nil - } - return &hbMonitor{ - timer: time.AfterFunc(2*dur, func() { - s.errs <- ErrNoHeartbeat - }), - } -} - -func (s *pullSubscription) cleanupSubscriptionAndRestoreConnHandler() { - s.consumer.Lock() - defer s.consumer.Unlock() - if s.subscription == nil { - return - } - if s.hbMonitor != nil { - s.hbMonitor.Stop() - } - s.subscription.Unsubscribe() - s.subscription = nil - atomic.StoreUint32(&s.consumer.isSubscribed, 0) - s.consumer.jetStream.conn.SetDisconnectErrHandler(s.disconnectHandler) - s.consumer.jetStream.conn.SetReconnectHandler(s.reconnectHandler) -} - -func msgSize(msg *nats.Msg) int { - if msg == nil { - return 0 - } - size := len(msg.Subject) + len(msg.Reply) + len(msg.Data) - return size -} - -// pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription]. -// Messages will be fetched up to given batch_size or until there are no more messages or timeout is returned -func (s *pullSubscription) pull(req pullRequest, subject string) error { - s.consumer.Lock() - defer s.consumer.Unlock() - if atomic.LoadUint32(&s.closed) == 1 { - return ErrMsgIteratorClosed - } - if req.Batch < 1 { - return fmt.Errorf("%w: batch size must be at least 1", nats.ErrInvalidArg) - } - reqJSON, err := json.Marshal(req) - if err != nil { - return err - } - - reply := s.subscription.Subject - if err := s.consumer.jetStream.conn.PublishRequest(subject, reply, reqJSON); err != nil { - return err - } - return nil -} diff --git a/nats.go b/nats.go index 4a9166f..b656ad9 100644 --- a/nats.go +++ b/nats.go @@ -510,31 +510,32 @@ type Conn struct { mu sync.RWMutex // Opts holds the configuration of the Conn. // Modifying the configuration of a running Conn is a race. - Opts Options - wg sync.WaitGroup - srvPool []*srv - current *srv - urls map[string]struct{} // Keep track of all known URLs (used by processInfo) - conn net.Conn - bw *natsWriter - br *natsReader - fch chan struct{} - info serverInfo - ssid int64 - subsMu sync.RWMutex - subs map[int64]*Subscription - ach *asyncCallbacksHandler - pongs []chan struct{} - scratch [scratchSize]byte - status Status - initc bool // true if the connection is performing the initial connect - err error - ps *parseState - ptmr *time.Timer - pout int - ar bool // abort reconnect - rqch chan struct{} - ws bool // true if a websocket connection + Opts Options + wg sync.WaitGroup + srvPool []*srv + current *srv + urls map[string]struct{} // Keep track of all known URLs (used by processInfo) + conn net.Conn + bw *natsWriter + br *natsReader + fch chan struct{} + info serverInfo + ssid int64 + subsMu sync.RWMutex + subs map[int64]*Subscription + ach *asyncCallbacksHandler + pongs []chan struct{} + scratch [scratchSize]byte + status Status + statListeners map[Status][]chan struct{} + initc bool // true if the connection is performing the initial connect + err error + ps *parseState + ptmr *time.Timer + pout int + ar bool // abort reconnect + rqch chan struct{} + ws bool // true if a websocket connection // New style response handler respSub string // The wildcard subject @@ -2181,7 +2182,7 @@ func (nc *Conn) processConnectInit() error { defer nc.conn.SetDeadline(time.Time{}) // Set our status to connecting. - nc.status = CONNECTING + nc.changeConnStatus(CONNECTING) // Process the INFO protocol received from the server err := nc.processExpectedInfo() @@ -2273,7 +2274,7 @@ func (nc *Conn) connect() (bool, error) { nc.initc = false } else if nc.Opts.RetryOnFailedConnect { nc.setup() - nc.status = RECONNECTING + nc.changeConnStatus(RECONNECTING) nc.bw.switchToPending() go nc.doReconnect(ErrNoServers) err = nil @@ -2507,7 +2508,7 @@ func (nc *Conn) sendConnect() error { } // This is where we are truly connected. - nc.status = CONNECTED + nc.changeConnStatus(CONNECTED) return nil } @@ -2682,7 +2683,7 @@ func (nc *Conn) doReconnect(err error) { if nc.ar { break } - nc.status = RECONNECTING + nc.changeConnStatus(RECONNECTING) continue } @@ -2700,7 +2701,7 @@ func (nc *Conn) doReconnect(err error) { // Now send off and clear pending buffer nc.err = nc.flushReconnectPendingItems() if nc.err != nil { - nc.status = RECONNECTING + nc.changeConnStatus(RECONNECTING) // Stop the ping timer (if set) nc.stopPingTimer() // Since processConnectInit() returned without error, the @@ -2714,7 +2715,7 @@ func (nc *Conn) doReconnect(err error) { nc.bw.doneWithPending() // This is where we are truly connected. - nc.status = CONNECTED + nc.changeConnStatus(CONNECTED) // If we are here with a retry on failed connect, indicate that the // initial connect is now complete. @@ -2753,7 +2754,7 @@ func (nc *Conn) processOpErr(err error) { if nc.Opts.AllowReconnect && nc.status == CONNECTED { // Set our new status - nc.status = RECONNECTING + nc.changeConnStatus(RECONNECTING) // Stop ping timer if set nc.stopPingTimer() if nc.conn != nil { @@ -2772,7 +2773,7 @@ func (nc *Conn) processOpErr(err error) { return } - nc.status = DISCONNECTED + nc.changeConnStatus(DISCONNECTED) nc.err = err nc.mu.Unlock() nc.close(CLOSED, true, nil) @@ -4958,11 +4959,11 @@ func (nc *Conn) clearPendingRequestCalls() { func (nc *Conn) close(status Status, doCBs bool, err error) { nc.mu.Lock() if nc.isClosed() { - nc.status = status + nc.changeConnStatus(status) nc.mu.Unlock() return } - nc.status = CLOSED + nc.changeConnStatus(CLOSED) // Kick the Go routines so they fall out. nc.kickFlusher() @@ -5021,7 +5022,7 @@ func (nc *Conn) close(status Status, doCBs bool, err error) { nc.subs = nil nc.subsMu.Unlock() - nc.status = status + nc.changeConnStatus(status) // Perform appropriate callback if needed for a disconnect. if doCBs { @@ -5166,7 +5167,7 @@ func (nc *Conn) drainConnection() { // Flip State nc.mu.Lock() - nc.status = DRAINING_PUBS + nc.changeConnStatus(DRAINING_PUBS) nc.mu.Unlock() // Do publish drain via Flush() call. @@ -5201,7 +5202,7 @@ func (nc *Conn) Drain() error { nc.mu.Unlock() return nil } - nc.status = DRAINING_SUBS + nc.changeConnStatus(DRAINING_SUBS) go nc.drainConnection() nc.mu.Unlock() @@ -5411,6 +5412,52 @@ func (nc *Conn) GetClientID() (uint64, error) { return nc.info.CID, nil } +func (nc *Conn) RegisterStatusChangeListener(status Status, ch chan struct{}) { + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.statListeners == nil { + nc.statListeners = make(map[Status][]chan struct{}) + } + if _, ok := nc.statListeners[status]; !ok { + nc.statListeners[status] = make([]chan struct{}, 0) + } + nc.statListeners[status] = append(nc.statListeners[status], ch) +} + +// sendStatusEvent sends connection status event to all channels. +// If channel is closed, or there is no listener, sendStatusEvent +// will not block. Lock should be held entering. +func (nc *Conn) sendStatusEvent(s Status) { +Loop: + for i := 0; i < len(nc.statListeners[s]); i++ { + // make sure channel is not closed + select { + case <-nc.statListeners[s][i]: + // if chan is closed, remove it + nc.statListeners[s][i] = nc.statListeners[s][len(nc.statListeners[s])-1] + nc.statListeners[s] = nc.statListeners[s][:len(nc.statListeners[s])-1] + i-- + continue Loop + default: + } + // only send event if someone's listening + select { + case nc.statListeners[s][i] <- struct{}{}: + default: + } + } +} + +// changeConnStatus changes connections status and sends events +// to all listeners. Lock should be held entering. +func (nc *Conn) changeConnStatus(status Status) { + if nc == nil { + return + } + nc.sendStatusEvent(status) + nc.status = status +} + // NkeyOptionFromSeed will load an nkey pair from a seed file. // It will return the NKey Option and will handle // signing of nonce challenges from the server. It will take