mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-04 23:52:48 +08:00
update: HTTP API
This commit is contained in:
141
consumer.go
141
consumer.go
@@ -3,8 +3,10 @@ package mq
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -253,6 +255,14 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
|||||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||||
}
|
}
|
||||||
c.pool.Start(c.opts.numOfWorkers)
|
c.pool.Start(c.opts.numOfWorkers)
|
||||||
|
if c.opts.enableHTTPApi {
|
||||||
|
go func() {
|
||||||
|
_, err := c.StartHTTPAPI()
|
||||||
|
if err != nil {
|
||||||
|
log.Println(fmt.Sprintf("Error on running HTTP API %s", err.Error()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
// Infinite loop to continuously read messages and reconnect if needed.
|
// Infinite loop to continuously read messages and reconnect if needed.
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -341,3 +351,134 @@ func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
|||||||
func (c *Consumer) Conn() net.Conn {
|
func (c *Consumer) Conn() net.Conn {
|
||||||
return c.conn
|
return c.conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartHTTPAPI starts an HTTP server on a random available port and registers API endpoints.
|
||||||
|
// It returns the port number the server is listening on.
|
||||||
|
func (c *Consumer) StartHTTPAPI() (int, error) {
|
||||||
|
// Listen on a random port.
|
||||||
|
ln, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to start listener: %w", err)
|
||||||
|
}
|
||||||
|
port := ln.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
// Create a new HTTP mux and register endpoints.
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/stats", c.handleStats)
|
||||||
|
mux.HandleFunc("/update", c.handleUpdate)
|
||||||
|
mux.HandleFunc("/pause", c.handlePause)
|
||||||
|
mux.HandleFunc("/resume", c.handleResume)
|
||||||
|
mux.HandleFunc("/stop", c.handleStop)
|
||||||
|
|
||||||
|
// Start the server in a new goroutine.
|
||||||
|
go func() {
|
||||||
|
// Log errors if the HTTP server stops.
|
||||||
|
if err := http.Serve(ln, mux); err != nil {
|
||||||
|
log.Printf("HTTP server error on port %d: %v", port, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Printf("HTTP API for consumer %s started on port %d", c.id, port)
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStats responds with JSON containing consumer and pool metrics.
|
||||||
|
func (c *Consumer) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gather consumer and pool stats using formatted metrics.
|
||||||
|
stats := map[string]interface{}{
|
||||||
|
"consumer_id": c.id,
|
||||||
|
"queue": c.queue,
|
||||||
|
"pool_metrics": c.pool.FormattedMetrics(),
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(stats); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode stats: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUpdate accepts a POST request with a JSON payload to update the consumer's pool configuration.
|
||||||
|
// It reuses the consumer's Update method which updates the pool configuration.
|
||||||
|
func (c *Consumer) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the request body.
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
// Call the Update method on the consumer (which in turn updates the pool configuration).
|
||||||
|
if err := c.Update(r.Context(), body); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to update configuration: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
resp := map[string]string{"status": "configuration updated"}
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePause pauses the consumer's pool.
|
||||||
|
func (c *Consumer) handlePause(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Pause(r.Context()); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to pause consumer: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
resp := map[string]string{"status": "consumer paused"}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleResume resumes the consumer's pool.
|
||||||
|
func (c *Consumer) handleResume(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Resume(r.Context()); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to resume consumer: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
resp := map[string]string{"status": "consumer resumed"}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStop stops the consumer's pool.
|
||||||
|
func (c *Consumer) handleStop(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop the consumer.
|
||||||
|
if err := c.Stop(r.Context()); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to stop consumer: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
resp := map[string]string{"status": "consumer stopped"}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
@@ -2,9 +2,10 @@ package dag
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
"github.com/oarkflow/mq/consts"
|
"github.com/oarkflow/mq/consts"
|
||||||
"log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (tm *DAG) Consume(ctx context.Context) error {
|
func (tm *DAG) Consume(ctx context.Context) error {
|
||||||
@@ -16,7 +17,7 @@ func (tm *DAG) Consume(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tm *DAG) AssignTopic(topic string) {
|
func (tm *DAG) AssignTopic(topic string) {
|
||||||
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
|
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()), mq.WithHTTPApi(tm.server.Options().HTTPApi()))
|
||||||
tm.consumerTopic = topic
|
tm.consumerTopic = topic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -10,6 +10,6 @@ import (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
n := &tasks.Node6{}
|
n := &tasks.Node6{}
|
||||||
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
|
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithBrokerURL(":8081"), mq.WithHTTPApi(true), mq.WithWorkerPool(100, 4, 50000))
|
||||||
consumer1.Consume(context.Background())
|
consumer1.Consume(context.Background())
|
||||||
}
|
}
|
||||||
|
@@ -12,7 +12,7 @@ func main() {
|
|||||||
task := mq.Task{
|
task := mq.Task{
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
publisher := mq.NewPublisher("publish-1")
|
publisher := mq.NewPublisher("publish-1", mq.WithBrokerURL(":8081"))
|
||||||
for i := 0; i < 10000000; i++ {
|
for i := 0; i < 10000000; i++ {
|
||||||
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
||||||
err := publisher.Publish(context.Background(), task, "queue1")
|
err := publisher.Publish(context.Background(), task, "queue1")
|
||||||
|
@@ -3,13 +3,13 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
mq2 "github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/examples/tasks"
|
"github.com/oarkflow/mq/examples/tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
b := mq2.NewBroker(mq2.WithCallback(tasks.Callback))
|
b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithBrokerURL(":8081"))
|
||||||
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||||
b.NewQueue("queue1")
|
b.NewQueue("queue1")
|
||||||
b.NewQueue("queue2")
|
b.NewQueue("queue2")
|
||||||
|
@@ -162,7 +162,7 @@ func notify(taskID string, result mq.Result) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flow := dag.NewDAG("Sample DAG", "sample-dag", notify, mq.WithBrokerURL(":8083"))
|
flow := dag.NewDAG("Sample DAG", "sample-dag", notify, mq.WithBrokerURL(":8083"), mq.WithHTTPApi(true))
|
||||||
flow.AddNode(dag.Page, "Form", "Form", &Form{})
|
flow.AddNode(dag.Page, "Form", "Form", &Form{})
|
||||||
flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{})
|
flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{})
|
||||||
flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{})
|
flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{})
|
||||||
|
76
mq.go
76
mq.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/errors"
|
"github.com/oarkflow/errors"
|
||||||
@@ -120,33 +121,84 @@ type TLSConfig struct {
|
|||||||
UseTLS bool
|
UseTLS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NEW: RateLimiter implementation
|
// RateLimiter implementation
|
||||||
type RateLimiter struct {
|
type RateLimiter struct {
|
||||||
C chan struct{}
|
mu sync.Mutex
|
||||||
|
C chan struct{}
|
||||||
|
ticker *time.Ticker
|
||||||
|
rate int
|
||||||
|
burst int
|
||||||
|
stop chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Modified RateLimiter: use blocking send to avoid discarding tokens.
|
// NewRateLimiter creates a new RateLimiter with the specified rate and burst.
|
||||||
func NewRateLimiter(rate int, burst int) *RateLimiter {
|
func NewRateLimiter(rate int, burst int) *RateLimiter {
|
||||||
rl := &RateLimiter{C: make(chan struct{}, burst)}
|
rl := &RateLimiter{
|
||||||
ticker := time.NewTicker(time.Second / time.Duration(rate))
|
C: make(chan struct{}, burst),
|
||||||
go func() {
|
rate: rate,
|
||||||
for range ticker.C {
|
burst: burst,
|
||||||
rl.C <- struct{}{} // blocking send; tokens queue for deferred task processing
|
stop: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}()
|
rl.ticker = time.NewTicker(time.Second / time.Duration(rate))
|
||||||
|
go rl.run()
|
||||||
return rl
|
return rl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run is the internal goroutine that periodically sends tokens.
|
||||||
|
func (rl *RateLimiter) run() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-rl.ticker.C:
|
||||||
|
// Blocking send to ensure token accumulation doesn't discard tokens.
|
||||||
|
rl.mu.Lock()
|
||||||
|
// Try sending token, but don't block if channel is full.
|
||||||
|
select {
|
||||||
|
case rl.C <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
rl.mu.Unlock()
|
||||||
|
case <-rl.stop:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until a token is available.
|
||||||
func (rl *RateLimiter) Wait() {
|
func (rl *RateLimiter) Wait() {
|
||||||
<-rl.C
|
<-rl.C
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update allows dynamic adjustment of rate and burst at runtime.
|
||||||
|
// It immediately applies the new settings.
|
||||||
|
func (rl *RateLimiter) Update(newRate, newBurst int) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
// Stop the old ticker.
|
||||||
|
rl.ticker.Stop()
|
||||||
|
// Replace the channel with a new one of the new burst capacity.
|
||||||
|
rl.C = make(chan struct{}, newBurst)
|
||||||
|
// Update internal state.
|
||||||
|
rl.rate = newRate
|
||||||
|
rl.burst = newBurst
|
||||||
|
// Start a new ticker with the updated rate.
|
||||||
|
rl.ticker = time.NewTicker(time.Second / time.Duration(newRate))
|
||||||
|
// The run goroutine will pick up tokens from the new ticker and use the new channel.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop terminates the rate limiter's internal goroutine.
|
||||||
|
func (rl *RateLimiter) Stop() {
|
||||||
|
close(rl.stop)
|
||||||
|
rl.ticker.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
storage TaskStorage
|
storage TaskStorage
|
||||||
consumerOnSubscribe func(ctx context.Context, topic, consumerName string)
|
consumerOnSubscribe func(ctx context.Context, topic, consumerName string)
|
||||||
consumerOnClose func(ctx context.Context, topic, consumerName string)
|
consumerOnClose func(ctx context.Context, topic, consumerName string)
|
||||||
notifyResponse func(context.Context, Result) error
|
notifyResponse func(context.Context, Result) error
|
||||||
brokerAddr string
|
brokerAddr string
|
||||||
|
enableHTTPApi bool
|
||||||
tlsConfig TLSConfig
|
tlsConfig TLSConfig
|
||||||
callback []func(context.Context, Result) Result
|
callback []func(context.Context, Result) Result
|
||||||
queueSize int
|
queueSize int
|
||||||
@@ -197,6 +249,10 @@ func (o *Options) BrokerAddr() string {
|
|||||||
return o.brokerAddr
|
return o.brokerAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Options) HTTPApi() bool {
|
||||||
|
return o.enableHTTPApi
|
||||||
|
}
|
||||||
|
|
||||||
func HeadersWithConsumerID(ctx context.Context, id string) map[string]string {
|
func HeadersWithConsumerID(ctx context.Context, id string) map[string]string {
|
||||||
return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson})
|
return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson})
|
||||||
}
|
}
|
||||||
|
32
options.go
32
options.go
@@ -2,10 +2,12 @@ package mq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/logger"
|
"github.com/oarkflow/mq/logger"
|
||||||
|
"github.com/oarkflow/mq/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ThresholdConfig struct {
|
type ThresholdConfig struct {
|
||||||
@@ -57,12 +59,6 @@ func WithBatchSize(batchSize int) PoolOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithHealthServicePort(port int) PoolOption {
|
|
||||||
return func(p *Pool) {
|
|
||||||
p.port = port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithHandler(handler Handler) PoolOption {
|
func WithHandler(handler Handler) PoolOption {
|
||||||
return func(p *Pool) {
|
return func(p *Pool) {
|
||||||
p.handler = handler
|
p.handler = handler
|
||||||
@@ -117,9 +113,22 @@ func WithPlugin(plugin Plugin) PoolOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var BrokerAddr string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if BrokerAddr == "" {
|
||||||
|
port, err := utils.GetRandomPort()
|
||||||
|
if err != nil {
|
||||||
|
BrokerAddr = ":8081"
|
||||||
|
} else {
|
||||||
|
BrokerAddr = fmt.Sprintf(":%d", port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func defaultOptions() *Options {
|
func defaultOptions() *Options {
|
||||||
return &Options{
|
return &Options{
|
||||||
brokerAddr: ":8081",
|
brokerAddr: BrokerAddr,
|
||||||
maxRetries: 5,
|
maxRetries: 5,
|
||||||
respondPendingResult: true,
|
respondPendingResult: true,
|
||||||
initialDelay: 2 * time.Second,
|
initialDelay: 2 * time.Second,
|
||||||
@@ -130,8 +139,6 @@ func defaultOptions() *Options {
|
|||||||
maxMemoryLoad: 5000000,
|
maxMemoryLoad: 5000000,
|
||||||
storage: NewMemoryTaskStorage(10 * time.Minute),
|
storage: NewMemoryTaskStorage(10 * time.Minute),
|
||||||
logger: logger.NewDefaultLogger(),
|
logger: logger.NewDefaultLogger(),
|
||||||
BrokerRateLimiter: NewRateLimiter(10, 5),
|
|
||||||
ConsumerRateLimiter: NewRateLimiter(10, 5),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,6 +201,13 @@ func WithTLS(enableTLS bool, certPath, keyPath string) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithHTTPApi - Option to enable/disable TLS
|
||||||
|
func WithHTTPApi(flag bool) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.enableHTTPApi = flag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithCAPath - Option to enable/disable TLS
|
// WithCAPath - Option to enable/disable TLS
|
||||||
func WithCAPath(caPath string) Option {
|
func WithCAPath(caPath string) Option {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
|
132
pool.go
132
pool.go
@@ -6,34 +6,40 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/utils"
|
|
||||||
|
|
||||||
"github.com/oarkflow/log"
|
"github.com/oarkflow/log"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Callback is called when a task processing is completed.
|
||||||
type Callback func(ctx context.Context, result Result) error
|
type Callback func(ctx context.Context, result Result) error
|
||||||
|
|
||||||
|
// CompletionCallback is called when the pool completes a graceful shutdown.
|
||||||
type CompletionCallback func()
|
type CompletionCallback func()
|
||||||
|
|
||||||
|
// Metrics holds cumulative pool metrics.
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalTasks int64
|
TotalTasks int64 // total number of tasks processed
|
||||||
CompletedTasks int64
|
CompletedTasks int64 // number of successfully processed tasks
|
||||||
ErrorCount int64
|
ErrorCount int64 // number of tasks that resulted in error
|
||||||
TotalMemoryUsed int64
|
TotalMemoryUsed int64 // current memory used (in bytes) by tasks in flight
|
||||||
TotalScheduled int64
|
TotalScheduled int64 // number of tasks scheduled
|
||||||
ExecutionTime int64
|
ExecutionTime int64 // cumulative execution time in milliseconds
|
||||||
|
CumulativeMemoryUsed int64 // cumulative memory used (sum of all task sizes) in bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Plugin is used to inject custom behavior before or after task processing.
|
||||||
type Plugin interface {
|
type Plugin interface {
|
||||||
Initialize(config interface{}) error
|
Initialize(config interface{}) error
|
||||||
BeforeTask(task *QueueTask)
|
BeforeTask(task *QueueTask)
|
||||||
AfterTask(task *QueueTask, result Result)
|
AfterTask(task *QueueTask, result Result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultPlugin is a no-op implementation of Plugin.
|
||||||
type DefaultPlugin struct{}
|
type DefaultPlugin struct{}
|
||||||
|
|
||||||
func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
|
func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
|
||||||
@@ -44,6 +50,7 @@ func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
|
|||||||
Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
|
Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeadLetterQueue stores tasks that have permanently failed.
|
||||||
type DeadLetterQueue struct {
|
type DeadLetterQueue struct {
|
||||||
tasks []*QueueTask
|
tasks []*QueueTask
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -66,6 +73,7 @@ func (dlq *DeadLetterQueue) Add(task *QueueTask) {
|
|||||||
Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
|
Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InMemoryMetricsRegistry stores metrics in memory.
|
||||||
type InMemoryMetricsRegistry struct {
|
type InMemoryMetricsRegistry struct {
|
||||||
metrics map[string]int64
|
metrics map[string]int64
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -98,11 +106,13 @@ func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
|
|||||||
return m.metrics[metricName]
|
return m.metrics[metricName]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WarningThresholds defines thresholds for warnings.
|
||||||
type WarningThresholds struct {
|
type WarningThresholds struct {
|
||||||
HighMemory int64
|
HighMemory int64 // in bytes
|
||||||
LongExecution time.Duration
|
LongExecution time.Duration // threshold duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DynamicConfig holds runtime configuration values.
|
||||||
type DynamicConfig struct {
|
type DynamicConfig struct {
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
BatchSize int
|
BatchSize int
|
||||||
@@ -112,7 +122,7 @@ type DynamicConfig struct {
|
|||||||
MaxRetries int
|
MaxRetries int
|
||||||
ReloadInterval time.Duration
|
ReloadInterval time.Duration
|
||||||
WarningThreshold WarningThresholds
|
WarningThreshold WarningThresholds
|
||||||
NumberOfWorkers int // <-- new field for worker count
|
NumberOfWorkers int // new field for worker count
|
||||||
}
|
}
|
||||||
|
|
||||||
var Config = &DynamicConfig{
|
var Config = &DynamicConfig{
|
||||||
@@ -124,12 +134,13 @@ var Config = &DynamicConfig{
|
|||||||
MaxRetries: 3,
|
MaxRetries: 3,
|
||||||
ReloadInterval: 30 * time.Second,
|
ReloadInterval: 30 * time.Second,
|
||||||
WarningThreshold: WarningThresholds{
|
WarningThreshold: WarningThresholds{
|
||||||
HighMemory: 1 * 1024 * 1024,
|
HighMemory: 1 * 1024 * 1024, // 1 MB
|
||||||
LongExecution: 2 * time.Second,
|
LongExecution: 2 * time.Second,
|
||||||
},
|
},
|
||||||
NumberOfWorkers: 5, // <-- default worker count
|
NumberOfWorkers: 5, // default worker count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pool represents the worker pool processing tasks.
|
||||||
type Pool struct {
|
type Pool struct {
|
||||||
taskStorage TaskStorage
|
taskStorage TaskStorage
|
||||||
scheduler *Scheduler
|
scheduler *Scheduler
|
||||||
@@ -166,9 +177,9 @@ type Pool struct {
|
|||||||
circuitBreakerFailureCount int32
|
circuitBreakerFailureCount int32
|
||||||
gracefulShutdownTimeout time.Duration
|
gracefulShutdownTimeout time.Duration
|
||||||
plugins []Plugin
|
plugins []Plugin
|
||||||
port int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPool creates and starts a new pool with the given number of workers.
|
||||||
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
||||||
pool := &Pool{
|
pool := &Pool{
|
||||||
stop: make(chan struct{}),
|
stop: make(chan struct{}),
|
||||||
@@ -179,7 +190,6 @@ func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
|||||||
backoffDuration: Config.BackoffDuration,
|
backoffDuration: Config.BackoffDuration,
|
||||||
maxRetries: Config.MaxRetries,
|
maxRetries: Config.MaxRetries,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
port: 1234,
|
|
||||||
dlq: NewDeadLetterQueue(),
|
dlq: NewDeadLetterQueue(),
|
||||||
metricsRegistry: NewInMemoryMetricsRegistry(),
|
metricsRegistry: NewInMemoryMetricsRegistry(),
|
||||||
diagnosticsEnabled: true,
|
diagnosticsEnabled: true,
|
||||||
@@ -198,7 +208,6 @@ func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
|||||||
pool.Start(numOfWorkers)
|
pool.Start(numOfWorkers)
|
||||||
startConfigReloader(pool)
|
startConfigReloader(pool)
|
||||||
go pool.dynamicWorkerScaler()
|
go pool.dynamicWorkerScaler()
|
||||||
go pool.startHealthServer()
|
|
||||||
return pool
|
return pool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,15 +359,21 @@ func (wp *Pool) processNextBatch() {
|
|||||||
func (wp *Pool) handleTask(task *QueueTask) {
|
func (wp *Pool) handleTask(task *QueueTask) {
|
||||||
ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
|
ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// Measure memory usage for the task.
|
||||||
taskSize := int64(utils.SizeOf(task.payload))
|
taskSize := int64(utils.SizeOf(task.payload))
|
||||||
|
// Increase current memory usage.
|
||||||
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, taskSize)
|
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, taskSize)
|
||||||
|
// Increase cumulative memory usage.
|
||||||
|
atomic.AddInt64(&wp.metrics.CumulativeMemoryUsed, taskSize)
|
||||||
atomic.AddInt64(&wp.metrics.TotalTasks, 1)
|
atomic.AddInt64(&wp.metrics.TotalTasks, 1)
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
result := wp.handler(ctx, task.payload)
|
result := wp.handler(ctx, task.payload)
|
||||||
executionTime := time.Since(startTime).Milliseconds()
|
execMs := time.Since(startTime).Milliseconds()
|
||||||
atomic.AddInt64(&wp.metrics.ExecutionTime, executionTime)
|
atomic.AddInt64(&wp.metrics.ExecutionTime, execMs)
|
||||||
if wp.thresholds.LongExecution > 0 && executionTime > wp.thresholds.LongExecution.Milliseconds() {
|
|
||||||
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime)
|
if wp.thresholds.LongExecution > 0 && execMs > wp.thresholds.LongExecution.Milliseconds() {
|
||||||
|
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", execMs)
|
||||||
}
|
}
|
||||||
if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
|
if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
|
||||||
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
|
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
|
||||||
@@ -383,15 +398,14 @@ func (wp *Pool) handleTask(task *QueueTask) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
atomic.AddInt64(&wp.metrics.CompletedTasks, 1)
|
atomic.AddInt64(&wp.metrics.CompletedTasks, 1)
|
||||||
// Reset failure count on success if using circuit breaker
|
// Reset failure count on success if using circuit breaker.
|
||||||
if wp.circuitBreaker.Enabled {
|
if wp.circuitBreaker.Enabled {
|
||||||
atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
|
atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Diagnostics logging if enabled
|
|
||||||
if wp.diagnosticsEnabled {
|
if wp.diagnosticsEnabled {
|
||||||
wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime)
|
wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", execMs)
|
||||||
}
|
}
|
||||||
if wp.callback != nil {
|
if wp.callback != nil {
|
||||||
if err := wp.callback(ctx, result); err != nil {
|
if err := wp.callback(ctx, result); err != nil {
|
||||||
@@ -400,8 +414,9 @@ func (wp *Pool) handleTask(task *QueueTask) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = wp.taskStorage.DeleteTask(task.payload.ID)
|
_ = wp.taskStorage.DeleteTask(task.payload.ID)
|
||||||
|
// Reduce current memory usage.
|
||||||
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
|
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
|
||||||
wp.metricsRegistry.Register("task_execution_time", executionTime)
|
wp.metricsRegistry.Register("task_execution_time", execMs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wp *Pool) backoffAndStore(task *QueueTask) {
|
func (wp *Pool) backoffAndStore(task *QueueTask) {
|
||||||
@@ -582,6 +597,38 @@ func (wp *Pool) Metrics() Metrics {
|
|||||||
return wp.metrics
|
return wp.metrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FormattedMetrics is a helper struct to present human-readable metrics.
|
||||||
|
type FormattedMetrics struct {
|
||||||
|
TotalTasks int64 `json:"total_tasks"`
|
||||||
|
CompletedTasks int64 `json:"completed_tasks"`
|
||||||
|
ErrorCount int64 `json:"error_count"`
|
||||||
|
CurrentMemoryUsed string `json:"current_memory_used"`
|
||||||
|
CumulativeMemoryUsed string `json:"cumulative_memory_used"`
|
||||||
|
TotalScheduled int64 `json:"total_scheduled"`
|
||||||
|
CumulativeExecution string `json:"cumulative_execution"`
|
||||||
|
AverageExecution string `json:"average_execution"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormattedMetrics returns a formatted version of the pool metrics.
|
||||||
|
func (wp *Pool) FormattedMetrics() FormattedMetrics {
|
||||||
|
// Update TotalScheduled from the scheduler.
|
||||||
|
wp.metrics.TotalScheduled = int64(len(wp.scheduler.tasks))
|
||||||
|
var avgExec time.Duration
|
||||||
|
if wp.metrics.CompletedTasks > 0 {
|
||||||
|
avgExec = time.Duration(wp.metrics.ExecutionTime/wp.metrics.CompletedTasks) * time.Millisecond
|
||||||
|
}
|
||||||
|
return FormattedMetrics{
|
||||||
|
TotalTasks: wp.metrics.TotalTasks,
|
||||||
|
CompletedTasks: wp.metrics.CompletedTasks,
|
||||||
|
ErrorCount: wp.metrics.ErrorCount,
|
||||||
|
CurrentMemoryUsed: utils.FormatBytes(wp.metrics.TotalMemoryUsed),
|
||||||
|
CumulativeMemoryUsed: utils.FormatBytes(wp.metrics.CumulativeMemoryUsed),
|
||||||
|
TotalScheduled: wp.metrics.TotalScheduled,
|
||||||
|
CumulativeExecution: (time.Duration(wp.metrics.ExecutionTime) * time.Millisecond).String(),
|
||||||
|
AverageExecution: avgExec.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
|
func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
|
||||||
|
|
||||||
func (wp *Pool) dynamicWorkerScaler() {
|
func (wp *Pool) dynamicWorkerScaler() {
|
||||||
@@ -602,40 +649,7 @@ func (wp *Pool) dynamicWorkerScaler() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wp *Pool) startHealthServer() {
|
// UpdateConfig updates pool configuration via a POOL_UPDATE command.
|
||||||
mux := http.NewServeMux()
|
|
||||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
status := "OK"
|
|
||||||
if wp.gracefulShutdown {
|
|
||||||
status = "shutting down"
|
|
||||||
}
|
|
||||||
_, _ = fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
|
|
||||||
status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
|
|
||||||
})
|
|
||||||
server := &http.Server{
|
|
||||||
Addr: ":8080",
|
|
||||||
Handler: mux,
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
wp.logger.Info().Msg("Starting health server on :8080")
|
|
||||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
wp.logger.Error().Err(err).Msg("Health server failed")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-wp.stop
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := server.Shutdown(ctx); err != nil {
|
|
||||||
wp.logger.Error().Err(err).Msg("Health server shutdown failed")
|
|
||||||
} else {
|
|
||||||
wp.logger.Info().Msg("Health server shutdown gracefully")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// New method to update pool configuration via POOL_UPDATE command.
|
|
||||||
func (wp *Pool) UpdateConfig(newConfig *DynamicConfig) error {
|
func (wp *Pool) UpdateConfig(newConfig *DynamicConfig) error {
|
||||||
if err := validateDynamicConfig(newConfig); err != nil {
|
if err := validateDynamicConfig(newConfig); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@@ -13,3 +13,17 @@ func ConnectionsEqual(c1, c2 net.Conn) bool {
|
|||||||
}
|
}
|
||||||
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
|
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRandomPort returns a free port chosen by the operating system.
|
||||||
|
func GetRandomPort() (int, error) {
|
||||||
|
// Bind to port 0, which instructs the OS to assign an available port.
|
||||||
|
ln, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
// Extract the port number from the listener's address.
|
||||||
|
addr := ln.Addr().(*net.TCPAddr)
|
||||||
|
return addr.Port, nil
|
||||||
|
}
|
||||||
|
15
utils/str.go
15
utils/str.go
@@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,3 +15,17 @@ func FromByte(b []byte) string {
|
|||||||
p := unsafe.SliceData(b)
|
p := unsafe.SliceData(b)
|
||||||
return unsafe.String(p, len(b))
|
return unsafe.String(p, len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FormatBytes(bytes int64) string {
|
||||||
|
units := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
|
||||||
|
if bytes == 0 {
|
||||||
|
return fmt.Sprintf("0 B")
|
||||||
|
}
|
||||||
|
size := float64(bytes)
|
||||||
|
unitIndex := 0
|
||||||
|
for size >= 1024 && unitIndex < len(units)-1 {
|
||||||
|
size /= 1024
|
||||||
|
unitIndex++
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.2f %s", size, units[unitIndex])
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user