update: HTTP API

This commit is contained in:
Oarkflow
2025-03-30 16:55:32 +05:45
parent ba75adc7d6
commit 31a9fb8ba7
11 changed files with 340 additions and 85 deletions

View File

@@ -3,8 +3,10 @@ package mq
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"net/http"
"strings" "strings"
"time" "time"
@@ -253,6 +255,14 @@ func (c *Consumer) Consume(ctx context.Context) error {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
} }
c.pool.Start(c.opts.numOfWorkers) c.pool.Start(c.opts.numOfWorkers)
if c.opts.enableHTTPApi {
go func() {
_, err := c.StartHTTPAPI()
if err != nil {
log.Println(fmt.Sprintf("Error on running HTTP API %s", err.Error()))
}
}()
}
// Infinite loop to continuously read messages and reconnect if needed. // Infinite loop to continuously read messages and reconnect if needed.
for { for {
select { select {
@@ -341,3 +351,134 @@ func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
func (c *Consumer) Conn() net.Conn { func (c *Consumer) Conn() net.Conn {
return c.conn return c.conn
} }
// StartHTTPAPI starts an HTTP server on a random available port and registers API endpoints.
// It returns the port number the server is listening on.
func (c *Consumer) StartHTTPAPI() (int, error) {
// Listen on a random port.
ln, err := net.Listen("tcp", ":0")
if err != nil {
return 0, fmt.Errorf("failed to start listener: %w", err)
}
port := ln.Addr().(*net.TCPAddr).Port
// Create a new HTTP mux and register endpoints.
mux := http.NewServeMux()
mux.HandleFunc("/stats", c.handleStats)
mux.HandleFunc("/update", c.handleUpdate)
mux.HandleFunc("/pause", c.handlePause)
mux.HandleFunc("/resume", c.handleResume)
mux.HandleFunc("/stop", c.handleStop)
// Start the server in a new goroutine.
go func() {
// Log errors if the HTTP server stops.
if err := http.Serve(ln, mux); err != nil {
log.Printf("HTTP server error on port %d: %v", port, err)
}
}()
log.Printf("HTTP API for consumer %s started on port %d", c.id, port)
return port, nil
}
// handleStats responds with JSON containing consumer and pool metrics.
func (c *Consumer) handleStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
// Gather consumer and pool stats using formatted metrics.
stats := map[string]interface{}{
"consumer_id": c.id,
"queue": c.queue,
"pool_metrics": c.pool.FormattedMetrics(),
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(stats); err != nil {
http.Error(w, fmt.Sprintf("failed to encode stats: %v", err), http.StatusInternalServerError)
}
}
// handleUpdate accepts a POST request with a JSON payload to update the consumer's pool configuration.
// It reuses the consumer's Update method which updates the pool configuration.
func (c *Consumer) handleUpdate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
// Read the request body.
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest)
return
}
defer r.Body.Close()
// Call the Update method on the consumer (which in turn updates the pool configuration).
if err := c.Update(r.Context(), body); err != nil {
http.Error(w, fmt.Sprintf("failed to update configuration: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
resp := map[string]string{"status": "configuration updated"}
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
// handlePause pauses the consumer's pool.
func (c *Consumer) handlePause(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
if err := c.Pause(r.Context()); err != nil {
http.Error(w, fmt.Sprintf("failed to pause consumer: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
resp := map[string]string{"status": "consumer paused"}
json.NewEncoder(w).Encode(resp)
}
// handleResume resumes the consumer's pool.
func (c *Consumer) handleResume(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
if err := c.Resume(r.Context()); err != nil {
http.Error(w, fmt.Sprintf("failed to resume consumer: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
resp := map[string]string{"status": "consumer resumed"}
json.NewEncoder(w).Encode(resp)
}
// handleStop stops the consumer's pool.
func (c *Consumer) handleStop(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
// Stop the consumer.
if err := c.Stop(r.Context()); err != nil {
http.Error(w, fmt.Sprintf("failed to stop consumer: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
resp := map[string]string{"status": "consumer stopped"}
json.NewEncoder(w).Encode(resp)
}

View File

@@ -2,9 +2,10 @@ package dag
import ( import (
"context" "context"
"log"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/consts"
"log"
) )
func (tm *DAG) Consume(ctx context.Context) error { func (tm *DAG) Consume(ctx context.Context) error {
@@ -16,7 +17,7 @@ func (tm *DAG) Consume(ctx context.Context) error {
} }
func (tm *DAG) AssignTopic(topic string) { func (tm *DAG) AssignTopic(topic string) {
tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL())) tm.consumer = mq.NewConsumer(topic, topic, tm.ProcessTask, mq.WithRespondPendingResult(false), mq.WithBrokerURL(tm.server.URL()), mq.WithHTTPApi(tm.server.Options().HTTPApi()))
tm.consumerTopic = topic tm.consumerTopic = topic
} }

View File

@@ -10,6 +10,6 @@ import (
func main() { func main() {
n := &tasks.Node6{} n := &tasks.Node6{}
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000)) consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithBrokerURL(":8081"), mq.WithHTTPApi(true), mq.WithWorkerPool(100, 4, 50000))
consumer1.Consume(context.Background()) consumer1.Consume(context.Background())
} }

View File

@@ -12,7 +12,7 @@ func main() {
task := mq.Task{ task := mq.Task{
Payload: payload, Payload: payload,
} }
publisher := mq.NewPublisher("publish-1") publisher := mq.NewPublisher("publish-1", mq.WithBrokerURL(":8081"))
for i := 0; i < 10000000; i++ { for i := 0; i < 10000000; i++ {
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) // publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
err := publisher.Publish(context.Background(), task, "queue1") err := publisher.Publish(context.Background(), task, "queue1")

View File

@@ -3,13 +3,13 @@ package main
import ( import (
"context" "context"
mq2 "github.com/oarkflow/mq" "github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/examples/tasks"
) )
func main() { func main() {
b := mq2.NewBroker(mq2.WithCallback(tasks.Callback)) b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithBrokerURL(":8081"))
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) // b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
b.NewQueue("queue1") b.NewQueue("queue1")
b.NewQueue("queue2") b.NewQueue("queue2")

View File

@@ -162,7 +162,7 @@ func notify(taskID string, result mq.Result) {
} }
func main() { func main() {
flow := dag.NewDAG("Sample DAG", "sample-dag", notify, mq.WithBrokerURL(":8083")) flow := dag.NewDAG("Sample DAG", "sample-dag", notify, mq.WithBrokerURL(":8083"), mq.WithHTTPApi(true))
flow.AddNode(dag.Page, "Form", "Form", &Form{}) flow.AddNode(dag.Page, "Form", "Form", &Form{})
flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{}) flow.AddNode(dag.Function, "NodeA", "NodeA", &NodeA{})
flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{}) flow.AddNode(dag.Function, "NodeB", "NodeB", &NodeB{})

76
mq.go
View File

@@ -7,6 +7,7 @@ import (
"log" "log"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/oarkflow/errors" "github.com/oarkflow/errors"
@@ -120,33 +121,84 @@ type TLSConfig struct {
UseTLS bool UseTLS bool
} }
// NEW: RateLimiter implementation // RateLimiter implementation
type RateLimiter struct { type RateLimiter struct {
C chan struct{} mu sync.Mutex
C chan struct{}
ticker *time.Ticker
rate int
burst int
stop chan struct{}
} }
// Modified RateLimiter: use blocking send to avoid discarding tokens. // NewRateLimiter creates a new RateLimiter with the specified rate and burst.
func NewRateLimiter(rate int, burst int) *RateLimiter { func NewRateLimiter(rate int, burst int) *RateLimiter {
rl := &RateLimiter{C: make(chan struct{}, burst)} rl := &RateLimiter{
ticker := time.NewTicker(time.Second / time.Duration(rate)) C: make(chan struct{}, burst),
go func() { rate: rate,
for range ticker.C { burst: burst,
rl.C <- struct{}{} // blocking send; tokens queue for deferred task processing stop: make(chan struct{}),
} }
}() rl.ticker = time.NewTicker(time.Second / time.Duration(rate))
go rl.run()
return rl return rl
} }
// run is the internal goroutine that periodically sends tokens.
func (rl *RateLimiter) run() {
for {
select {
case <-rl.ticker.C:
// Blocking send to ensure token accumulation doesn't discard tokens.
rl.mu.Lock()
// Try sending token, but don't block if channel is full.
select {
case rl.C <- struct{}{}:
default:
}
rl.mu.Unlock()
case <-rl.stop:
return
}
}
}
// Wait blocks until a token is available.
func (rl *RateLimiter) Wait() { func (rl *RateLimiter) Wait() {
<-rl.C <-rl.C
} }
// Update allows dynamic adjustment of rate and burst at runtime.
// It immediately applies the new settings.
func (rl *RateLimiter) Update(newRate, newBurst int) {
rl.mu.Lock()
defer rl.mu.Unlock()
// Stop the old ticker.
rl.ticker.Stop()
// Replace the channel with a new one of the new burst capacity.
rl.C = make(chan struct{}, newBurst)
// Update internal state.
rl.rate = newRate
rl.burst = newBurst
// Start a new ticker with the updated rate.
rl.ticker = time.NewTicker(time.Second / time.Duration(newRate))
// The run goroutine will pick up tokens from the new ticker and use the new channel.
}
// Stop terminates the rate limiter's internal goroutine.
func (rl *RateLimiter) Stop() {
close(rl.stop)
rl.ticker.Stop()
}
type Options struct { type Options struct {
storage TaskStorage storage TaskStorage
consumerOnSubscribe func(ctx context.Context, topic, consumerName string) consumerOnSubscribe func(ctx context.Context, topic, consumerName string)
consumerOnClose func(ctx context.Context, topic, consumerName string) consumerOnClose func(ctx context.Context, topic, consumerName string)
notifyResponse func(context.Context, Result) error notifyResponse func(context.Context, Result) error
brokerAddr string brokerAddr string
enableHTTPApi bool
tlsConfig TLSConfig tlsConfig TLSConfig
callback []func(context.Context, Result) Result callback []func(context.Context, Result) Result
queueSize int queueSize int
@@ -197,6 +249,10 @@ func (o *Options) BrokerAddr() string {
return o.brokerAddr return o.brokerAddr
} }
func (o *Options) HTTPApi() bool {
return o.enableHTTPApi
}
func HeadersWithConsumerID(ctx context.Context, id string) map[string]string { func HeadersWithConsumerID(ctx context.Context, id string) map[string]string {
return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson}) return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson})
} }

View File

@@ -2,10 +2,12 @@ package mq
import ( import (
"context" "context"
"fmt"
"runtime" "runtime"
"time" "time"
"github.com/oarkflow/mq/logger" "github.com/oarkflow/mq/logger"
"github.com/oarkflow/mq/utils"
) )
type ThresholdConfig struct { type ThresholdConfig struct {
@@ -57,12 +59,6 @@ func WithBatchSize(batchSize int) PoolOption {
} }
} }
func WithHealthServicePort(port int) PoolOption {
return func(p *Pool) {
p.port = port
}
}
func WithHandler(handler Handler) PoolOption { func WithHandler(handler Handler) PoolOption {
return func(p *Pool) { return func(p *Pool) {
p.handler = handler p.handler = handler
@@ -117,9 +113,22 @@ func WithPlugin(plugin Plugin) PoolOption {
} }
} }
var BrokerAddr string
func init() {
if BrokerAddr == "" {
port, err := utils.GetRandomPort()
if err != nil {
BrokerAddr = ":8081"
} else {
BrokerAddr = fmt.Sprintf(":%d", port)
}
}
}
func defaultOptions() *Options { func defaultOptions() *Options {
return &Options{ return &Options{
brokerAddr: ":8081", brokerAddr: BrokerAddr,
maxRetries: 5, maxRetries: 5,
respondPendingResult: true, respondPendingResult: true,
initialDelay: 2 * time.Second, initialDelay: 2 * time.Second,
@@ -130,8 +139,6 @@ func defaultOptions() *Options {
maxMemoryLoad: 5000000, maxMemoryLoad: 5000000,
storage: NewMemoryTaskStorage(10 * time.Minute), storage: NewMemoryTaskStorage(10 * time.Minute),
logger: logger.NewDefaultLogger(), logger: logger.NewDefaultLogger(),
BrokerRateLimiter: NewRateLimiter(10, 5),
ConsumerRateLimiter: NewRateLimiter(10, 5),
} }
} }
@@ -194,6 +201,13 @@ func WithTLS(enableTLS bool, certPath, keyPath string) Option {
} }
} }
// WithHTTPApi - Option to enable/disable TLS
func WithHTTPApi(flag bool) Option {
return func(o *Options) {
o.enableHTTPApi = flag
}
}
// WithCAPath - Option to enable/disable TLS // WithCAPath - Option to enable/disable TLS
func WithCAPath(caPath string) Option { func WithCAPath(caPath string) Option {
return func(o *Options) { return func(o *Options) {

132
pool.go
View File

@@ -6,34 +6,40 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oarkflow/mq/utils"
"github.com/oarkflow/log" "github.com/oarkflow/log"
"github.com/oarkflow/mq/utils"
) )
// Callback is called when a task processing is completed.
type Callback func(ctx context.Context, result Result) error type Callback func(ctx context.Context, result Result) error
// CompletionCallback is called when the pool completes a graceful shutdown.
type CompletionCallback func() type CompletionCallback func()
// Metrics holds cumulative pool metrics.
type Metrics struct { type Metrics struct {
TotalTasks int64 TotalTasks int64 // total number of tasks processed
CompletedTasks int64 CompletedTasks int64 // number of successfully processed tasks
ErrorCount int64 ErrorCount int64 // number of tasks that resulted in error
TotalMemoryUsed int64 TotalMemoryUsed int64 // current memory used (in bytes) by tasks in flight
TotalScheduled int64 TotalScheduled int64 // number of tasks scheduled
ExecutionTime int64 ExecutionTime int64 // cumulative execution time in milliseconds
CumulativeMemoryUsed int64 // cumulative memory used (sum of all task sizes) in bytes
} }
// Plugin is used to inject custom behavior before or after task processing.
type Plugin interface { type Plugin interface {
Initialize(config interface{}) error Initialize(config interface{}) error
BeforeTask(task *QueueTask) BeforeTask(task *QueueTask)
AfterTask(task *QueueTask, result Result) AfterTask(task *QueueTask, result Result)
} }
// DefaultPlugin is a no-op implementation of Plugin.
type DefaultPlugin struct{} type DefaultPlugin struct{}
func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil } func (dp *DefaultPlugin) Initialize(config interface{}) error { return nil }
@@ -44,6 +50,7 @@ func (dp *DefaultPlugin) AfterTask(task *QueueTask, result Result) {
Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked") Logger.Info().Str("taskID", task.payload.ID).Msg("AfterTask plugin invoked")
} }
// DeadLetterQueue stores tasks that have permanently failed.
type DeadLetterQueue struct { type DeadLetterQueue struct {
tasks []*QueueTask tasks []*QueueTask
mu sync.Mutex mu sync.Mutex
@@ -66,6 +73,7 @@ func (dlq *DeadLetterQueue) Add(task *QueueTask) {
Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue") Logger.Warn().Str("taskID", task.payload.ID).Msg("Task added to Dead Letter Queue")
} }
// InMemoryMetricsRegistry stores metrics in memory.
type InMemoryMetricsRegistry struct { type InMemoryMetricsRegistry struct {
metrics map[string]int64 metrics map[string]int64
mu sync.RWMutex mu sync.RWMutex
@@ -98,11 +106,13 @@ func (m *InMemoryMetricsRegistry) Get(metricName string) interface{} {
return m.metrics[metricName] return m.metrics[metricName]
} }
// WarningThresholds defines thresholds for warnings.
type WarningThresholds struct { type WarningThresholds struct {
HighMemory int64 HighMemory int64 // in bytes
LongExecution time.Duration LongExecution time.Duration // threshold duration
} }
// DynamicConfig holds runtime configuration values.
type DynamicConfig struct { type DynamicConfig struct {
Timeout time.Duration Timeout time.Duration
BatchSize int BatchSize int
@@ -112,7 +122,7 @@ type DynamicConfig struct {
MaxRetries int MaxRetries int
ReloadInterval time.Duration ReloadInterval time.Duration
WarningThreshold WarningThresholds WarningThreshold WarningThresholds
NumberOfWorkers int // <-- new field for worker count NumberOfWorkers int // new field for worker count
} }
var Config = &DynamicConfig{ var Config = &DynamicConfig{
@@ -124,12 +134,13 @@ var Config = &DynamicConfig{
MaxRetries: 3, MaxRetries: 3,
ReloadInterval: 30 * time.Second, ReloadInterval: 30 * time.Second,
WarningThreshold: WarningThresholds{ WarningThreshold: WarningThresholds{
HighMemory: 1 * 1024 * 1024, HighMemory: 1 * 1024 * 1024, // 1 MB
LongExecution: 2 * time.Second, LongExecution: 2 * time.Second,
}, },
NumberOfWorkers: 5, // <-- default worker count NumberOfWorkers: 5, // default worker count
} }
// Pool represents the worker pool processing tasks.
type Pool struct { type Pool struct {
taskStorage TaskStorage taskStorage TaskStorage
scheduler *Scheduler scheduler *Scheduler
@@ -166,9 +177,9 @@ type Pool struct {
circuitBreakerFailureCount int32 circuitBreakerFailureCount int32
gracefulShutdownTimeout time.Duration gracefulShutdownTimeout time.Duration
plugins []Plugin plugins []Plugin
port int
} }
// NewPool creates and starts a new pool with the given number of workers.
func NewPool(numOfWorkers int, opts ...PoolOption) *Pool { func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
pool := &Pool{ pool := &Pool{
stop: make(chan struct{}), stop: make(chan struct{}),
@@ -179,7 +190,6 @@ func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
backoffDuration: Config.BackoffDuration, backoffDuration: Config.BackoffDuration,
maxRetries: Config.MaxRetries, maxRetries: Config.MaxRetries,
logger: Logger, logger: Logger,
port: 1234,
dlq: NewDeadLetterQueue(), dlq: NewDeadLetterQueue(),
metricsRegistry: NewInMemoryMetricsRegistry(), metricsRegistry: NewInMemoryMetricsRegistry(),
diagnosticsEnabled: true, diagnosticsEnabled: true,
@@ -198,7 +208,6 @@ func NewPool(numOfWorkers int, opts ...PoolOption) *Pool {
pool.Start(numOfWorkers) pool.Start(numOfWorkers)
startConfigReloader(pool) startConfigReloader(pool)
go pool.dynamicWorkerScaler() go pool.dynamicWorkerScaler()
go pool.startHealthServer()
return pool return pool
} }
@@ -350,15 +359,21 @@ func (wp *Pool) processNextBatch() {
func (wp *Pool) handleTask(task *QueueTask) { func (wp *Pool) handleTask(task *QueueTask) {
ctx, cancel := context.WithTimeout(task.ctx, wp.timeout) ctx, cancel := context.WithTimeout(task.ctx, wp.timeout)
defer cancel() defer cancel()
// Measure memory usage for the task.
taskSize := int64(utils.SizeOf(task.payload)) taskSize := int64(utils.SizeOf(task.payload))
// Increase current memory usage.
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, taskSize) atomic.AddInt64(&wp.metrics.TotalMemoryUsed, taskSize)
// Increase cumulative memory usage.
atomic.AddInt64(&wp.metrics.CumulativeMemoryUsed, taskSize)
atomic.AddInt64(&wp.metrics.TotalTasks, 1) atomic.AddInt64(&wp.metrics.TotalTasks, 1)
startTime := time.Now() startTime := time.Now()
result := wp.handler(ctx, task.payload) result := wp.handler(ctx, task.payload)
executionTime := time.Since(startTime).Milliseconds() execMs := time.Since(startTime).Milliseconds()
atomic.AddInt64(&wp.metrics.ExecutionTime, executionTime) atomic.AddInt64(&wp.metrics.ExecutionTime, execMs)
if wp.thresholds.LongExecution > 0 && executionTime > wp.thresholds.LongExecution.Milliseconds() {
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", executionTime) if wp.thresholds.LongExecution > 0 && execMs > wp.thresholds.LongExecution.Milliseconds() {
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Exceeded execution time threshold: %d ms", execMs)
} }
if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory { if wp.thresholds.HighMemory > 0 && taskSize > wp.thresholds.HighMemory {
wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize) wp.logger.Warn().Str("taskID", task.payload.ID).Msgf("Memory usage %d exceeded threshold", taskSize)
@@ -383,15 +398,14 @@ func (wp *Pool) handleTask(task *QueueTask) {
} }
} else { } else {
atomic.AddInt64(&wp.metrics.CompletedTasks, 1) atomic.AddInt64(&wp.metrics.CompletedTasks, 1)
// Reset failure count on success if using circuit breaker // Reset failure count on success if using circuit breaker.
if wp.circuitBreaker.Enabled { if wp.circuitBreaker.Enabled {
atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0) atomic.StoreInt32(&wp.circuitBreakerFailureCount, 0)
} }
} }
// Diagnostics logging if enabled
if wp.diagnosticsEnabled { if wp.diagnosticsEnabled {
wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", executionTime) wp.logger.Info().Str("taskID", task.payload.ID).Msgf("Task executed in %d ms", execMs)
} }
if wp.callback != nil { if wp.callback != nil {
if err := wp.callback(ctx, result); err != nil { if err := wp.callback(ctx, result); err != nil {
@@ -400,8 +414,9 @@ func (wp *Pool) handleTask(task *QueueTask) {
} }
} }
_ = wp.taskStorage.DeleteTask(task.payload.ID) _ = wp.taskStorage.DeleteTask(task.payload.ID)
// Reduce current memory usage.
atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize) atomic.AddInt64(&wp.metrics.TotalMemoryUsed, -taskSize)
wp.metricsRegistry.Register("task_execution_time", executionTime) wp.metricsRegistry.Register("task_execution_time", execMs)
} }
func (wp *Pool) backoffAndStore(task *QueueTask) { func (wp *Pool) backoffAndStore(task *QueueTask) {
@@ -582,6 +597,38 @@ func (wp *Pool) Metrics() Metrics {
return wp.metrics return wp.metrics
} }
// FormattedMetrics is a helper struct to present human-readable metrics.
type FormattedMetrics struct {
TotalTasks int64 `json:"total_tasks"`
CompletedTasks int64 `json:"completed_tasks"`
ErrorCount int64 `json:"error_count"`
CurrentMemoryUsed string `json:"current_memory_used"`
CumulativeMemoryUsed string `json:"cumulative_memory_used"`
TotalScheduled int64 `json:"total_scheduled"`
CumulativeExecution string `json:"cumulative_execution"`
AverageExecution string `json:"average_execution"`
}
// FormattedMetrics returns a formatted version of the pool metrics.
func (wp *Pool) FormattedMetrics() FormattedMetrics {
// Update TotalScheduled from the scheduler.
wp.metrics.TotalScheduled = int64(len(wp.scheduler.tasks))
var avgExec time.Duration
if wp.metrics.CompletedTasks > 0 {
avgExec = time.Duration(wp.metrics.ExecutionTime/wp.metrics.CompletedTasks) * time.Millisecond
}
return FormattedMetrics{
TotalTasks: wp.metrics.TotalTasks,
CompletedTasks: wp.metrics.CompletedTasks,
ErrorCount: wp.metrics.ErrorCount,
CurrentMemoryUsed: utils.FormatBytes(wp.metrics.TotalMemoryUsed),
CumulativeMemoryUsed: utils.FormatBytes(wp.metrics.CumulativeMemoryUsed),
TotalScheduled: wp.metrics.TotalScheduled,
CumulativeExecution: (time.Duration(wp.metrics.ExecutionTime) * time.Millisecond).String(),
AverageExecution: avgExec.String(),
}
}
func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler } func (wp *Pool) Scheduler() *Scheduler { return wp.scheduler }
func (wp *Pool) dynamicWorkerScaler() { func (wp *Pool) dynamicWorkerScaler() {
@@ -602,40 +649,7 @@ func (wp *Pool) dynamicWorkerScaler() {
} }
} }
func (wp *Pool) startHealthServer() { // UpdateConfig updates pool configuration via a POOL_UPDATE command.
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
status := "OK"
if wp.gracefulShutdown {
status = "shutting down"
}
_, _ = fmt.Fprintf(w, "status: %s\nworkers: %d\nqueueLength: %d\n",
status, atomic.LoadInt32(&wp.numOfWorkers), len(wp.taskQueue))
})
server := &http.Server{
Addr: ":8080",
Handler: mux,
}
go func() {
wp.logger.Info().Msg("Starting health server on :8080")
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
wp.logger.Error().Err(err).Msg("Health server failed")
}
}()
go func() {
<-wp.stop
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
wp.logger.Error().Err(err).Msg("Health server shutdown failed")
} else {
wp.logger.Info().Msg("Health server shutdown gracefully")
}
}()
}
// New method to update pool configuration via POOL_UPDATE command.
func (wp *Pool) UpdateConfig(newConfig *DynamicConfig) error { func (wp *Pool) UpdateConfig(newConfig *DynamicConfig) error {
if err := validateDynamicConfig(newConfig); err != nil { if err := validateDynamicConfig(newConfig); err != nil {
return err return err

View File

@@ -13,3 +13,17 @@ func ConnectionsEqual(c1, c2 net.Conn) bool {
} }
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2) return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
} }
// GetRandomPort returns a free port chosen by the operating system.
func GetRandomPort() (int, error) {
// Bind to port 0, which instructs the OS to assign an available port.
ln, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
defer ln.Close()
// Extract the port number from the listener's address.
addr := ln.Addr().(*net.TCPAddr)
return addr.Port, nil
}

View File

@@ -1,6 +1,7 @@
package utils package utils
import ( import (
"fmt"
"unsafe" "unsafe"
) )
@@ -14,3 +15,17 @@ func FromByte(b []byte) string {
p := unsafe.SliceData(b) p := unsafe.SliceData(b)
return unsafe.String(p, len(b)) return unsafe.String(p, len(b))
} }
func FormatBytes(bytes int64) string {
units := []string{"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
if bytes == 0 {
return fmt.Sprintf("0 B")
}
size := float64(bytes)
unitIndex := 0
for size >= 1024 && unitIndex < len(units)-1 {
size /= 1024
unitIndex++
}
return fmt.Sprintf("%.2f %s", size, units[unitIndex])
}