mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-04 23:52:48 +08:00
update: HTTP API
This commit is contained in:
141
consumer.go
141
consumer.go
@@ -3,8 +3,10 @@ package mq
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -253,6 +255,14 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||
}
|
||||
c.pool.Start(c.opts.numOfWorkers)
|
||||
if c.opts.enableHTTPApi {
|
||||
go func() {
|
||||
_, err := c.StartHTTPAPI()
|
||||
if err != nil {
|
||||
log.Println(fmt.Sprintf("Error on running HTTP API %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Infinite loop to continuously read messages and reconnect if needed.
|
||||
for {
|
||||
select {
|
||||
@@ -341,3 +351,134 @@ func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||||
func (c *Consumer) Conn() net.Conn {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
// StartHTTPAPI starts an HTTP server on a random available port and registers API endpoints.
|
||||
// It returns the port number the server is listening on.
|
||||
func (c *Consumer) StartHTTPAPI() (int, error) {
|
||||
// Listen on a random port.
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to start listener: %w", err)
|
||||
}
|
||||
port := ln.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// Create a new HTTP mux and register endpoints.
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/stats", c.handleStats)
|
||||
mux.HandleFunc("/update", c.handleUpdate)
|
||||
mux.HandleFunc("/pause", c.handlePause)
|
||||
mux.HandleFunc("/resume", c.handleResume)
|
||||
mux.HandleFunc("/stop", c.handleStop)
|
||||
|
||||
// Start the server in a new goroutine.
|
||||
go func() {
|
||||
// Log errors if the HTTP server stops.
|
||||
if err := http.Serve(ln, mux); err != nil {
|
||||
log.Printf("HTTP server error on port %d: %v", port, err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("HTTP API for consumer %s started on port %d", c.id, port)
|
||||
return port, nil
|
||||
}
|
||||
|
||||
// handleStats responds with JSON containing consumer and pool metrics.
|
||||
func (c *Consumer) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Gather consumer and pool stats using formatted metrics.
|
||||
stats := map[string]interface{}{
|
||||
"consumer_id": c.id,
|
||||
"queue": c.queue,
|
||||
"pool_metrics": c.pool.FormattedMetrics(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(stats); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode stats: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdate accepts a POST request with a JSON payload to update the consumer's pool configuration.
|
||||
// It reuses the consumer's Update method which updates the pool configuration.
|
||||
func (c *Consumer) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Read the request body.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Call the Update method on the consumer (which in turn updates the pool configuration).
|
||||
if err := c.Update(r.Context(), body); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to update configuration: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := map[string]string{"status": "configuration updated"}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePause pauses the consumer's pool.
|
||||
func (c *Consumer) handlePause(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.Pause(r.Context()); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to pause consumer: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := map[string]string{"status": "consumer paused"}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleResume resumes the consumer's pool.
|
||||
func (c *Consumer) handleResume(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.Resume(r.Context()); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to resume consumer: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := map[string]string{"status": "consumer resumed"}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleStop stops the consumer's pool.
|
||||
func (c *Consumer) handleStop(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Stop the consumer.
|
||||
if err := c.Stop(r.Context()); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to stop consumer: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := map[string]string{"status": "consumer stopped"}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
@@ -2,9 +2,10 @@ package dag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"log"
|
||||
)
|
||||
|
||||
func (tm *DAG) Consume(ctx context.Context) error {
|
||||
@@ -16,7 +17,7 @@ func (tm *DAG) Consume(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (tm *DAG) AssignTopic(topic string) {
|
||||
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()))
|
||||
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()), mq.WithHTTPApi(tm.server.Options().HTTPApi()))
|
||||
tm.consumerTopic = topic
|
||||
}
|
||||
|
||||
|
@@ -10,6 +10,6 @@ import (
|
||||
|
||||
func main() {
|
||||
n := &tasks.Node6{}
|
||||
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
|
||||
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithBrokerURL(":8081"), mq.WithHTTPApi(true), mq.WithWorkerPool(100, 4, 50000))
|
||||
consumer1.Consume(context.Background())
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ func main() {
|
||||
task := mq.Task{
|
||||
Payload: payload,
|
||||
}
|
||||
publisher := mq.NewPublisher("publish-1")
|
||||
publisher := mq.NewPublisher("publish-1", mq.WithBrokerURL(":8081"))
|
||||
for i := 0; i < 10000000; i++ {
|
||||
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
||||
err := publisher.Publish(context.Background(), task, "queue1")
|
||||
|
@@ -3,13 +3,13 @@ package main
|
||||
import (
|
||||
"context"
|
||||
|
||||
mq2 "github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq"
|
||||
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
)
|
||||
|
||||
func main() {
|
||||
b := mq2.NewBroker(mq2.WithCallback(tasks.Callback))
|
||||
b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithBrokerURL(":8081"))
|
||||
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||
b.NewQueue("queue1")
|
||||
b.NewQueue("queue2")
|
||||
|
@@ -162,7 +162,7 @@ func notify(taskID string, result mq.Result) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
flow := dag.NewDAG("Sample DAG", "sample-dag", notify, mq.WithBrokerURL(":8083"))
|
||||
flow := dag.NewDAG("Sample DAG", "sample-dag", notify, mq.WithBrokerURL(":8083"), mq.WithHTTPApi(true))
|
||||
flow.AddNode(dag.Page, "Form", "Form", &Form{})
|
||||
flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{})
|
||||
flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{})
|
||||
|
76
mq.go
76
mq.go
@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/errors"
|
||||
@@ -120,33 +121,84 @@ type TLSConfig struct {
|
||||
UseTLS bool
|
||||
}
|
||||
|
||||
// NEW: RateLimiter implementation
|
||||
// RateLimiter implementation
|
||||
type RateLimiter struct {
|
||||
C chan struct{}
|
||||
mu sync.Mutex
|
||||
C chan struct{}
|
||||
ticker *time.Ticker
|
||||
rate int
|
||||
burst int
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
// Modified RateLimiter: use blocking send to avoid discarding tokens.
|
||||
// NewRateLimiter creates a new RateLimiter with the specified rate and burst.
|
||||
func NewRateLimiter(rate int, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{C: make(chan struct{}, burst)}
|
||||
ticker := time.NewTicker(time.Second / time.Duration(rate))
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
rl.C <- struct{}{} // blocking send; tokens queue for deferred task processing
|
||||
}
|
||||
}()
|
||||
rl := &RateLimiter{
|
||||
C: make(chan struct{}, burst),
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
rl.ticker = time.NewTicker(time.Second / time.Duration(rate))
|
||||
go rl.run()
|
||||
return rl
|
||||
}
|
||||
|
||||
// run is the internal goroutine that periodically sends tokens.
|
||||
func (rl *RateLimiter) run() {
|
||||
for {
|
||||
select {
|
||||
case <-rl.ticker.C:
|
||||
// Blocking send to ensure token accumulation doesn't discard tokens.
|
||||
rl.mu.Lock()
|
||||
// Try sending token, but don't block if channel is full.
|
||||
select {
|
||||
case rl.C <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
case <-rl.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait blocks until a token is available.
|
||||
func (rl *RateLimiter) Wait() {
|
||||
<-rl.C
|
||||
}
|
||||
|
||||
// Update allows dynamic adjustment of rate and burst at runtime.
|
||||
// It immediately applies the new settings.
|
||||
func (rl *RateLimiter) Update(newRate, newBurst int) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
// Stop the old ticker.
|
||||
rl.ticker.Stop()
|
||||
// Replace the channel with a new one of the new burst capacity.
|
||||
rl.C = make(chan struct{}, newBurst)
|
||||
// Update internal state.
|
||||
rl.rate = newRate
|
||||
rl.burst = newBurst
|
||||
// Start a new ticker with the updated rate.
|
||||
rl.ticker = time.NewTicker(time.Second / time.Duration(newRate))
|
||||
// The run goroutine will pick up tokens from the new ticker and use the new channel.
|
||||
}
|
||||
|
||||
// Stop terminates the rate limiter's internal goroutine.
|
||||
func (rl *RateLimiter) Stop() {
|
||||
close(rl.stop)
|
||||
rl.ticker.Stop()
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
storage TaskStorage
|
||||
consumerOnSubscribe func(ctx context.Context, topic, consumerName string)
|
||||
consumerOnClose func(ctx context.Context, topic, consumerName string)
|
||||
notifyResponse func(context.Context, Result) error
|
||||
brokerAddr string
|
||||
enableHTTPApi bool
|
||||
tlsConfig TLSConfig
|
||||
callback []func(context.Context, Result) Result
|
||||
queueSize int
|
||||
@@ -197,6 +249,10 @@ func (o *Options) BrokerAddr() string {
|
||||
return o.brokerAddr
|
||||
}
|
||||
|
||||
func (o *Options) HTTPApi() bool {
|
||||
return o.enableHTTPApi
|
||||
}
|
||||
|
||||
func HeadersWithConsumerID(ctx context.Context, id string) map[string]string {
|
||||
return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson})
|
||||
}
|
||||
|
32
options.go
32
options.go
@@ -2,10 +2,12 @@ package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/logger"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
type ThresholdConfig struct {
|
||||
@@ -57,12 +59,6 @@ func WithBatchSize(batchSize int) PoolOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithHealthServicePort(port int) PoolOption {
|
||||
return func(p *Pool) {
|
||||
p.port = port
|
||||
}
|
||||
}
|
||||
|
||||
func WithHandler(handler Handler) PoolOption {
|
||||
return func(p *Pool) {
|
||||
p.handler = handler
|
||||
@@ -117,9 +113,22 @@ func WithPlugin(plugin Plugin) PoolOption {
|
||||
}
|
||||
}
|
||||
|
||||
var BrokerAddr string
|
||||
|
||||
func init() {
|
||||
if BrokerAddr == "" {
|
||||
port, err := utils.GetRandomPort()
|
||||
if err != nil {
|
||||
BrokerAddr = ":8081"
|
||||
} else {
|
||||
BrokerAddr = fmt.Sprintf(":%d", port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func defaultOptions() *Options {
|
||||
return &Options{
|
||||
brokerAddr: ":8081",
|
||||
brokerAddr: BrokerAddr,
|
||||
maxRetries: 5,
|
||||
respondPendingResult: true,
|
||||
initialDelay: 2 * time.Second,
|
||||
@@ -130,8 +139,6 @@ func defaultOptions() *Options {
|
||||
maxMemoryLoad: 5000000,
|
||||
storage: NewMemoryTaskStorage(10 * time.Minute),
|
||||
logger: logger.NewDefaultLogger(),
|
||||
BrokerRateLimiter: NewRateLimiter(10, 5),
|
||||
ConsumerRateLimiter: NewRateLimiter(10, 5),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,6 +201,13 @@ func WithTLS(enableTLS bool, certPath, keyPath string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPApi - Option to enable/disable TLS
|
||||
func WithHTTPApi(flag bool) Option {
|
||||
return func(o *Options) {
|
||||
o.enableHTTPApi = flag
|
||||
}
|
||||
}
|
||||
|
||||
// WithCAPath - Option to enable/disable TLS
|
||||
func WithCAPath(caPath string) Option {
|
||||
return func(o *Options) {
|
||||
|
132
pool.go
132
pool.go
@@ -6,34 +6,40 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/utils"
|
||||
|
||||
"github.com/oarkflow/log"
|
||||
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
// Callback is called when a task processing is completed.
|
||||
type Callback func(ctx context.Context, result Result) error
|
||||
|
||||
// CompletionCallback is called when the pool completes a graceful shutdown.
|
||||
type CompletionCallback func()
|
||||
|
||||
// Metrics holds cumulative pool metrics.
|
||||
type Metrics struct {
|
||||
TotalTasks int64
|
||||
CompletedTasks int64
|
||||
ErrorCount int64
|
||||
TotalMemoryUsed int64
|
||||
TotalScheduled int64
|
||||
ExecutionTime int64
|
||||
TotalTasks int64 // total number of tasks processed
|
||||
CompletedTasks int64 // number of successfully processed tasks
|
||||
ErrorCount int64 // number of tasks that resulted in error
|
||||
TotalMemoryUsed int64 // current memory used (in bytes) by tasks in flight
|
||||
TotalScheduled int64 // number of tasks scheduled
|
||||
ExecutionTime int64 // cumulative execution time in milliseconds
|
||||
CumulativeMemoryUsed int64 // cumulative memory used (sum of all task sizes) in bytes
|
||||
}
|
||||
|
||||
// Plugin is used to inject custom behavior before or after task processing.
|
||||
type Plugin interface {
|
||||
Initialize(config interface{}) error
|
||||
BeforeTask(task *QueueTask)
|
||||
AfterTask(task *QueueTask, result Result)
|
||||
}
|
||||
|
||||
// DefaultPlugin is a no-op implementation of Plugin.
|
||||
type DefaultPlugin struct{}
|
||||
|
||||
func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
|
||||
@@ -44,6 +50,7 @@ func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
|
||||
Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
|
||||
}
|
||||
|
||||
// DeadLetterQueue stores tasks that have permanently failed.
|
||||
type DeadLetterQueue struct {
|
||||
tasks []*QueueTask
|
||||
mu sync.Mutex
|
||||
@@ -66,6 +73,7 @@ func (dlq *DeadLetterQueue) Add(task *QueueTask) {
|
||||
Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
|
||||
}
|
||||
|
||||
// InMemoryMetricsRegistry stores metrics in memory.
|
||||
type InMemoryMetricsRegistry struct {
|
||||
metrics map[string]int64
|
||||
mu sync.RWMutex
|
||||
@@ -98,11 +106,13 @@ func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
|
||||
return m.metrics[metricName]
|
||||
}
|
||||
|
||||
// WarningThresholds defines thresholds for warnings.
|
||||
type WarningThresholds struct {
|
||||
HighMemory int64
|
||||
LongExecution time.Duration
|
||||
HighMemory int64 // in bytes
|
||||
LongExecution time.Duration // threshold duration
|
||||
}
|
||||
|
||||
// DynamicConfig holds runtime configuration values.
|
||||
type DynamicConfig struct {
|
||||
Timeout time.Duration
|
||||
BatchSize int
|
||||
@@ -112,7 +122,7 @@ type DynamicConfig struct {
|
||||
MaxRetries int
|
||||
ReloadInterval time.Duration
|
||||
WarningThreshold WarningThresholds
|
||||
NumberOfWorkers int // <-- new field for worker count
|
||||
NumberOfWorkers int // new field for worker count
|
||||
}
|
||||
|
||||
var Config = &DynamicConfig{
|
||||
@@ -124,12 +134,13 @@ var Config = &DynamicConfig{
|
||||
MaxRetries: 3,
|
||||
ReloadInterval: 30 * time.Second,
|
||||
WarningThreshold: WarningThresholds{
|
||||
HighMemory: 1 * 1024 * 1024,
|
||||
HighMemory: 1 * 1024 * 1024, // 1 MB
|
||||
LongExecution: 2 * time.Second,
|
||||
},
|
||||
NumberOfWorkers: 5, // <-- default worker count
|
||||
NumberOfWorkers: 5, // default worker count
|
||||
}
|
||||
|
||||
// Pool represents the worker pool processing tasks.
|
||||
type Pool struct {
|
||||
taskStorage TaskStorage
|
||||
scheduler *Scheduler
|
||||
@@ -166,9 +177,9 @@ type Pool struct {
|
||||
circuitBreakerFailureCount int32
|
||||
gracefulShutdownTimeout time.Duration
|
||||
plugins []Plugin
|
||||
port int
|
||||
}
|
||||
|
||||
// NewPool creates and starts a new pool with the given number of workers.
|
||||
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
||||
pool := &Pool{
|
||||
stop: make(chan struct{}),
|
||||
@@ -179,7 +190,6 @@ func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
||||
backoffDuration: Config.BackoffDuration,
|
||||
maxRetries: Config.MaxRetries,
|
||||
logger: Logger,
|
||||
port: 1234,
|
||||
dlq: NewDeadLetterQueue(),
|
||||
metricsRegistry: NewInMemoryMetricsRegistry(),
|
||||
diagnosticsEnabled: true,
|
||||
@@ -198,7 +208,6 @@ func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
|
||||
pool.Start(numOfWorkers)
|
||||
startConfigReloader(pool)
|
||||
go pool.dynamicWorkerScaler()
|
||||
go pool.startHealthServer()
|
||||
return pool
|
||||
}
|
||||
|
||||
@@ -350,15 +359,21 @@ func (wp *Pool) processNextBatch() {
|
||||
func (wp *Pool) handleTask(task *QueueTask) {
|
||||
ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Measure memory usage for the task.
|
||||
taskSize := int64(utils.SizeOf(task.payload))
|
||||
// Increase current memory usage.
|
||||
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, taskSize)
|
||||
// Increase cumulative memory usage.
|
||||
atomic.AddInt64(&wp.metrics.CumulativeMemoryUsed, taskSize)
|
||||
atomic.AddInt64(&wp.metrics.TotalTasks, 1)
|
||||
startTime := time.Now()
|
||||
result := wp.handler(ctx, task.payload)
|
||||
executionTime := time.Since(startTime).Milliseconds()
|
||||
atomic.AddInt64(&wp.metrics.ExecutionTime, executionTime)
|
||||
if wp.thresholds.LongExecution > 0 && executionTime > wp.thresholds.LongExecution.Milliseconds() {
|
||||
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime)
|
||||
execMs := time.Since(startTime).Milliseconds()
|
||||
atomic.AddInt64(&wp.metrics.ExecutionTime, execMs)
|
||||
|
||||
if wp.thresholds.LongExecution > 0 && execMs > wp.thresholds.LongExecution.Milliseconds() {
|
||||
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", execMs)
|
||||
}
|
||||
if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
|
||||
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
|
||||
@@ -383,15 +398,14 @@ func (wp *Pool) handleTask(task *QueueTask) {
|
||||
}
|
||||
} else {
|
||||
atomic.AddInt64(&wp.metrics.CompletedTasks, 1)
|
||||
// Reset failure count on success if using circuit breaker
|
||||
// Reset failure count on success if using circuit breaker.
|
||||
if wp.circuitBreaker.Enabled {
|
||||
atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Diagnostics logging if enabled
|
||||
if wp.diagnosticsEnabled {
|
||||
wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime)
|
||||
wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", execMs)
|
||||
}
|
||||
if wp.callback != nil {
|
||||
if err := wp.callback(ctx, result); err != nil {
|
||||
@@ -400,8 +414,9 @@ func (wp *Pool) handleTask(task *QueueTask) {
|
||||
}
|
||||
}
|
||||
_ = wp.taskStorage.DeleteTask(task.payload.ID)
|
||||
// Reduce current memory usage.
|
||||
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
|
||||
wp.metricsRegistry.Register("task_execution_time", executionTime)
|
||||
wp.metricsRegistry.Register("task_execution_time", execMs)
|
||||
}
|
||||
|
||||
func (wp *Pool) backoffAndStore(task *QueueTask) {
|
||||
@@ -582,6 +597,38 @@ func (wp *Pool) Metrics() Metrics {
|
||||
return wp.metrics
|
||||
}
|
||||
|
||||
// FormattedMetrics is a helper struct to present human-readable metrics.
|
||||
type FormattedMetrics struct {
|
||||
TotalTasks int64 `json:"total_tasks"`
|
||||
CompletedTasks int64 `json:"completed_tasks"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
CurrentMemoryUsed string `json:"current_memory_used"`
|
||||
CumulativeMemoryUsed string `json:"cumulative_memory_used"`
|
||||
TotalScheduled int64 `json:"total_scheduled"`
|
||||
CumulativeExecution string `json:"cumulative_execution"`
|
||||
AverageExecution string `json:"average_execution"`
|
||||
}
|
||||
|
||||
// FormattedMetrics returns a formatted version of the pool metrics.
|
||||
func (wp *Pool) FormattedMetrics() FormattedMetrics {
|
||||
// Update TotalScheduled from the scheduler.
|
||||
wp.metrics.TotalScheduled = int64(len(wp.scheduler.tasks))
|
||||
var avgExec time.Duration
|
||||
if wp.metrics.CompletedTasks > 0 {
|
||||
avgExec = time.Duration(wp.metrics.ExecutionTime/wp.metrics.CompletedTasks) * time.Millisecond
|
||||
}
|
||||
return FormattedMetrics{
|
||||
TotalTasks: wp.metrics.TotalTasks,
|
||||
CompletedTasks: wp.metrics.CompletedTasks,
|
||||
ErrorCount: wp.metrics.ErrorCount,
|
||||
CurrentMemoryUsed: utils.FormatBytes(wp.metrics.TotalMemoryUsed),
|
||||
CumulativeMemoryUsed: utils.FormatBytes(wp.metrics.CumulativeMemoryUsed),
|
||||
TotalScheduled: wp.metrics.TotalScheduled,
|
||||
CumulativeExecution: (time.Duration(wp.metrics.ExecutionTime) * time.Millisecond).String(),
|
||||
AverageExecution: avgExec.String(),
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
|
||||
|
||||
func (wp *Pool) dynamicWorkerScaler() {
|
||||
@@ -602,40 +649,7 @@ func (wp *Pool) dynamicWorkerScaler() {
|
||||
}
|
||||
}
|
||||
|
||||
func (wp *Pool) startHealthServer() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
status := "OK"
|
||||
if wp.gracefulShutdown {
|
||||
status = "shutting down"
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
|
||||
status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
|
||||
})
|
||||
server := &http.Server{
|
||||
Addr: ":8080",
|
||||
Handler: mux,
|
||||
}
|
||||
go func() {
|
||||
wp.logger.Info().Msg("Starting health server on :8080")
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
wp.logger.Error().Err(err).Msg("Health server failed")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-wp.stop
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
wp.logger.Error().Err(err).Msg("Health server shutdown failed")
|
||||
} else {
|
||||
wp.logger.Info().Msg("Health server shutdown gracefully")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// New method to update pool configuration via POOL_UPDATE command.
|
||||
// UpdateConfig updates pool configuration via a POOL_UPDATE command.
|
||||
func (wp *Pool) UpdateConfig(newConfig *DynamicConfig) error {
|
||||
if err := validateDynamicConfig(newConfig); err != nil {
|
||||
return err
|
||||
|
@@ -13,3 +13,17 @@ func ConnectionsEqual(c1, c2 net.Conn) bool {
|
||||
}
|
||||
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
|
||||
}
|
||||
|
||||
// GetRandomPort returns a free port chosen by the operating system.
|
||||
func GetRandomPort() (int, error) {
|
||||
// Bind to port 0, which instructs the OS to assign an available port.
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Extract the port number from the listener's address.
|
||||
addr := ln.Addr().(*net.TCPAddr)
|
||||
return addr.Port, nil
|
||||
}
|
||||
|
15
utils/str.go
15
utils/str.go
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -14,3 +15,17 @@ func FromByte(b []byte) string {
|
||||
p := unsafe.SliceData(b)
|
||||
return unsafe.String(p, len(b))
|
||||
}
|
||||
|
||||
func FormatBytes(bytes int64) string {
|
||||
units := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
|
||||
if bytes == 0 {
|
||||
return fmt.Sprintf("0 B")
|
||||
}
|
||||
size := float64(bytes)
|
||||
unitIndex := 0
|
||||
for size >= 1024 && unitIndex < len(units)-1 {
|
||||
size /= 1024
|
||||
unitIndex++
|
||||
}
|
||||
return fmt.Sprintf("%.2f %s", size, units[unitIndex])
|
||||
}
|
||||
|
Reference in New Issue
Block a user