feat: use "GetTags" and "SetTags"

This commit is contained in:
sujit
2025-07-31 09:31:28 +05:45
parent d814019d73
commit 103e8f8d88
19 changed files with 1818 additions and 1069 deletions

View File

@@ -1,343 +0,0 @@
// apperror/apperror.go
package apperror
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
)
// APP_ENV values
const (
EnvDevelopment = "development"
EnvStaging = "staging"
EnvProduction = "production"
)
// AppError defines a structured application error
type AppError struct {
Code string `json:"code"` // 9-digit code: XXX|AA|DD|YY
Message string `json:"message"` // human-readable message
StatusCode int `json:"-"` // HTTP status, not serialized
Err error `json:"-"` // wrapped error, not serialized
Metadata map[string]any `json:"metadata,omitempty"` // optional extra info
StackTrace []string `json:"stackTrace,omitempty"`
}
// Error implements error interface
func (e *AppError) Error() string {
if e.Err != nil {
return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Err)
}
return fmt.Sprintf("[%s] %s", e.Code, e.Message)
}
// Unwrap enables errors.Is / errors.As
func (e *AppError) Unwrap() error {
return e.Err
}
// WithMetadata returns a shallow copy with added metadata key/value
func (e *AppError) WithMetadata(key string, val any) *AppError {
newMD := make(map[string]any, len(e.Metadata)+1)
for k, v := range e.Metadata {
newMD[k] = v
}
newMD[key] = val
return &AppError{
Code: e.Code,
Message: e.Message,
StatusCode: e.StatusCode,
Err: e.Err,
Metadata: newMD,
StackTrace: e.StackTrace,
}
}
// GetStackTraceArray returns the error stack trace as an array of strings
func (e *AppError) GetStackTraceArray() []string {
return e.StackTrace
}
// GetStackTraceString returns the error stack trace as a single string
func (e *AppError) GetStackTraceString() string {
return strings.Join(e.StackTrace, "\n")
}
// captureStackTrace returns a slice of strings representing the stack trace.
func captureStackTrace() []string {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
frames := runtime.CallersFrames(pcs[:n])
isDebug := os.Getenv("APP_DEBUG") == "true"
var stack []string
for {
frame, more := frames.Next()
var file string
if !isDebug {
file = "/" + filepath.Base(frame.File)
} else {
file = frame.File
}
if strings.HasSuffix(file, ".go") {
file = strings.TrimSuffix(file, ".go") + ".sec"
}
stack = append(stack, fmt.Sprintf("%s:%d %s", file, frame.Line, frame.Function))
if !more {
break
}
}
return stack
}
// buildCode constructs a 9-digit code: XXX|AA|DD|YY
func buildCode(httpCode, appCode, domainCode, errCode int) string {
return fmt.Sprintf("%03d%02d%02d%02d", httpCode, appCode, domainCode, errCode)
}
// New creates a fresh AppError
func New(httpCode, appCode, domainCode, errCode int, msg string) *AppError {
return &AppError{
Code: buildCode(httpCode, appCode, domainCode, errCode),
Message: msg,
StatusCode: httpCode,
// Prototype: no StackTrace captured at registration time.
}
}
// Modify Wrap to always capture a fresh stack trace.
func Wrap(err error, httpCode, appCode, domainCode, errCode int, msg string) *AppError {
return &AppError{
Code: buildCode(httpCode, appCode, domainCode, errCode),
Message: msg,
StatusCode: httpCode,
Err: err,
StackTrace: captureStackTrace(),
}
}
// New helper: Instance attaches the runtime stack trace to a prototype error.
func Instance(e *AppError) *AppError {
// Create a shallow copy and attach the current stack trace.
copyE := *e
copyE.StackTrace = captureStackTrace()
return &copyE
}
// Modify toAppError to instance a prototype if it lacks a stack trace.
func toAppError(err error) *AppError {
if err == nil {
return nil
}
var ae *AppError
if errors.As(err, &ae) {
if len(ae.StackTrace) == 0 { // Prototype without context.
return Instance(ae)
}
return ae
}
// fallback to internal error 500|00|00|00 with fresh stack trace.
return Wrap(err, http.StatusInternalServerError, 0, 0, 0, "Internal server error")
}
// onError, if set, is called before writing any JSON error
var onError func(*AppError)
func OnError(hook func(*AppError)) {
onError = hook
}
// WriteJSONError writes an error as JSON, includes X-Request-ID, hides details in production
func WriteJSONError(w http.ResponseWriter, r *http.Request, err error) {
appErr := toAppError(err)
// attach request ID
if rid := r.Header.Get("X-Request-ID"); rid != "" {
appErr = appErr.WithMetadata("request_id", rid)
}
// hook
if onError != nil {
onError(appErr)
}
// If no stack trace is present, capture current context stack trace.
if os.Getenv("APP_ENV") != EnvProduction {
appErr.StackTrace = captureStackTrace()
}
fmt.Println(appErr.StackTrace)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(appErr.StatusCode)
resp := map[string]any{
"code": appErr.Code,
"message": appErr.Message,
}
if len(appErr.Metadata) > 0 {
resp["metadata"] = appErr.Metadata
}
if os.Getenv("APP_ENV") != EnvProduction {
resp["stack"] = appErr.StackTrace
}
if appErr.Err != nil {
resp["details"] = appErr.Err.Error()
}
_ = json.NewEncoder(w).Encode(resp)
}
type ErrorRegistry struct {
registry map[string]*AppError
mu sync.RWMutex
}
func (er *ErrorRegistry) Get(name string) (*AppError, bool) {
er.mu.RLock()
defer er.mu.RUnlock()
e, ok := er.registry[name]
return e, ok
}
func (er *ErrorRegistry) Set(name string, e *AppError) {
er.mu.Lock()
defer er.mu.Unlock()
er.registry[name] = e
}
func (er *ErrorRegistry) Delete(name string) {
er.mu.Lock()
defer er.mu.Unlock()
delete(er.registry, name)
}
func (er *ErrorRegistry) List() []*AppError {
er.mu.RLock()
defer er.mu.RUnlock()
out := make([]*AppError, 0, len(er.registry))
for _, e := range er.registry {
// create a shallow copy and remove the StackTrace for listing
copyE := *e
copyE.StackTrace = nil
out = append(out, &copyE)
}
return out
}
func (er *ErrorRegistry) GetByCode(code string) (*AppError, bool) {
er.mu.RLock()
defer er.mu.RUnlock()
for _, e := range er.registry {
if e.Code == code {
return e, true
}
}
return nil, false
}
var (
registry *ErrorRegistry
)
// Register adds a named error; fails if name exists
func Register(name string, e *AppError) error {
if name == "" {
return fmt.Errorf("error name cannot be empty")
}
registry.Set(name, e)
return nil
}
// Update replaces an existing named error; fails if not found
func Update(name string, e *AppError) error {
if name == "" {
return fmt.Errorf("error name cannot be empty")
}
registry.Set(name, e)
return nil
}
// Unregister removes a named error
func Unregister(name string) error {
if name == "" {
return fmt.Errorf("error name cannot be empty")
}
registry.Delete(name)
return nil
}
// Get retrieves a named error
func Get(name string) (*AppError, bool) {
return registry.Get(name)
}
// GetByCode retrieves an error by its 9-digit code
func GetByCode(code string) (*AppError, bool) {
if code == "" {
return nil, false
}
return registry.GetByCode(code)
}
// List returns all registered errors
func List() []*AppError {
return registry.List()
}
// Is/As shortcuts updated to check all registered errors
func Is(err, target error) bool {
if errors.Is(err, target) {
return true
}
registry.mu.RLock()
defer registry.mu.RUnlock()
for _, e := range registry.registry {
if errors.Is(err, e) || errors.Is(e, target) {
return true
}
}
return false
}
func As(err error, target any) bool {
if errors.As(err, target) {
return true
}
registry.mu.RLock()
defer registry.mu.RUnlock()
for _, e := range registry.registry {
if errors.As(err, target) || errors.As(e, target) {
return true
}
}
return false
}
// HTTPMiddleware catches panics and converts to JSON 500
func HTTPMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
p := fmt.Errorf("panic: %v", rec)
WriteJSONError(w, r, Wrap(p, http.StatusInternalServerError, 0, 0, 0, "Internal server error"))
}
}()
next.ServeHTTP(w, r)
})
}
// preload some common errors (with 2-digit app/domain codes)
func init() {
registry = &ErrorRegistry{registry: make(map[string]*AppError)}
_ = Register("ErrNotFound", New(http.StatusNotFound, 1, 1, 1, "Resource not found")) // → "404010101"
_ = Register("ErrInvalidInput", New(http.StatusBadRequest, 1, 1, 2, "Invalid input provided")) // → "400010102"
_ = Register("ErrInternal", New(http.StatusInternalServerError, 1, 1, 0, "Internal server error")) // → "500010100"
_ = Register("ErrUnauthorized", New(http.StatusUnauthorized, 1, 1, 3, "Unauthorized")) // → "401010103"
_ = Register("ErrForbidden", New(http.StatusForbidden, 1, 1, 4, "Forbidden")) // → "403010104"
}

181
codec/README.md Normal file
View File

@@ -0,0 +1,181 @@
# Message Queue Codec
This package provides a robust, production-ready codec implementation for serializing, transmitting, and deserializing messages in a distributed messaging system.
## Features
- **Message Validation**: Comprehensive validation of message format and content
- **Efficient Serialization**: Pluggable serialization with JSON as default
- **Compression**: Optional payload compression for large messages
- **Encryption**: Optional message encryption for sensitive data
- **Large Message Support**: Automatic fragmentation and reassembly of large messages
- **Connection Health**: Heartbeat mechanism for connection monitoring
- **Performance Optimized**: Buffer pooling, efficient memory usage
- **Robust Error Handling**: Detailed error types and error wrapping
- **Timeout Management**: Context-aware deadline handling
- **Observability**: Built-in statistics tracking
## Usage
### Basic Message Sending/Receiving
```go
// Create a message
msg, err := codec.NewMessage(consts.CmdPublish, payload, "my-queue", headers)
if err != nil {
log.Fatalf("Failed to create message: %v", err)
}
// Send the message
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
codec := codec.NewCodec(codec.DefaultConfig())
if err := codec.SendMessage(ctx, conn, msg); err != nil {
log.Fatalf("Failed to send message: %v", err)
}
// Receive a message
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
receivedMsg, err := codec.ReadMessage(ctx, conn)
if err != nil {
log.Fatalf("Failed to receive message: %v", err)
}
```
### Custom Serialization
```go
// Set a custom marshaller/unmarshaller
codec.SetMarshaller(func(v any) ([]byte, error) {
// Custom serialization logic
return someCustomSerializer.Marshal(v)
})
codec.SetUnmarshaller(func(data []byte, v any) error {
// Custom deserialization logic
return someCustomSerializer.Unmarshal(data, v)
})
```
### Enabling Compression
```go
// Enable compression globally
codec.EnableCompression(true)
// Or configure it per codec instance
config := codec.DefaultConfig()
config.EnableCompression = true
codec := codec.NewCodec(config)
```
### Enabling Encryption
```go
// Generate a secure key
key := make([]byte, 32) // 256-bit key
if _, err := rand.Read(key); err != nil {
log.Fatalf("Failed to generate key: %v", err)
}
// Enable encryption globally
codec.EnableEncryption(true, key)
// Or configure it per serialization manager
config := codec.DefaultSerializationConfig()
config.EnableEncryption = true
config.EncryptionKey = key
config.PreferredCipher = "chacha20poly1305" // or "aes-gcm"
```
### Connection Health Monitoring
```go
// Create a heartbeat manager
codec := codec.NewCodec(codec.DefaultConfig())
hm := codec.NewHeartbeatManager(codec, conn)
// Configure heartbeat
hm.SetInterval(15 * time.Second)
hm.SetTimeout(45 * time.Second)
hm.SetOnFailure(func(err error) {
log.Printf("Heartbeat failure: %v", err)
// Take action like closing connection
})
// Start heartbeat monitoring
hm.Start()
defer hm.Stop()
```
## Configuration
The codec behavior can be customized through the `Config` struct:
```go
config := &codec.Config{
MaxMessageSize: 32 * 1024 * 1024, // 32MB max message size
MaxHeaderSize: 64 * 1024, // 64KB max header size
MaxQueueLength: 128, // Max queue name length
ReadTimeout: 15 * time.Second,
WriteTimeout: 10 * time.Second,
EnableCompression: true,
BufferPoolSize: 2000,
}
codec := codec.NewCodec(config)
```
## Error Handling
The codec provides detailed error types for different failure scenarios:
- `ErrMessageTooLarge`: Message exceeds maximum size
- `ErrInvalidMessage`: Invalid message format
- `ErrInvalidQueue`: Invalid queue name
- `ErrInvalidCommand`: Invalid command
- `ErrConnectionClosed`: Connection closed
- `ErrTimeout`: Operation timeout
- `ErrProtocolMismatch`: Protocol version mismatch
- `ErrFragmentationRequired`: Message requires fragmentation
- `ErrInvalidFragment`: Invalid message fragment
- `ErrFragmentTimeout`: Timed out waiting for fragments
- `ErrFragmentMissing`: Missing fragments in sequence
Error handling example:
```go
if err := codec.SendMessage(ctx, conn, msg); err != nil {
if errors.Is(err, codec.ErrMessageTooLarge) {
// Handle message size error
} else if errors.Is(err, codec.ErrTimeout) {
// Handle timeout error
} else {
// Handle other errors
}
}
```
## Testing
The codec package includes testing utilities for validation and performance testing:
```go
ts := codec.NewCodecTestSuite()
// Test basic message sending/receiving
msg, _ := codec.NewMessage(consts.CmdPublish, []byte("test"), "test-queue", nil)
if err := ts.SendReceiveTest(msg); err != nil {
log.Fatalf("Test failed: %v", err)
}
// Test fragmentation/reassembly
largePayload := make([]byte, 20*1024*1024) // 20MB payload
rand.Read(largePayload) // Fill with random data
if err := ts.FragmentationTest(largePayload); err != nil {
log.Fatalf("Fragmentation test failed: %v", err)
}
```

View File

@@ -4,103 +4,492 @@ import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io" // added for full reads
"net"
"sync"
"time" // added for handling deadlines
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/internal/bpool"
)
type Message struct {
Headers map[string]string `msgpack:"h"`
Queue string `msgpack:"q"`
Payload []byte `msgpack:"p"`
Command consts.CMD `msgpack:"c"`
// Protocol version for backward compatibility
const (
ProtocolVersion = uint8(1)
MaxMessageSize = 64 * 1024 * 1024 // 64MB default limit
MaxHeaderSize = 1024 * 1024 // 1MB header limit
MaxQueueLength = 255 // Max queue name length
FragmentationThreshold = 16 * 1024 * 1024 // Messages larger than 16MB will be fragmented
FragmentSize = 8 * 1024 * 1024 // 8MB fragment size
MaxFragments = 256 // Maximum fragments per message
)
// Error definitions
var (
ErrMessageTooLarge = errors.New("message exceeds maximum size")
ErrInvalidMessage = errors.New("invalid message format")
ErrInvalidQueue = errors.New("invalid queue name")
ErrInvalidCommand = errors.New("invalid command")
ErrConnectionClosed = errors.New("connection closed")
ErrTimeout = errors.New("operation timeout")
ErrProtocolMismatch = errors.New("protocol version mismatch")
ErrFragmentationRequired = errors.New("message requires fragmentation")
ErrInvalidFragment = errors.New("invalid message fragment")
ErrFragmentTimeout = errors.New("timed out waiting for fragments")
ErrFragmentMissing = errors.New("missing fragments in sequence")
)
// Config holds codec configuration
type Config struct {
MaxMessageSize uint32
MaxHeaderSize uint32
MaxQueueLength uint8
ReadTimeout time.Duration
WriteTimeout time.Duration
EnableCompression bool
BufferPoolSize int
}
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
// DefaultConfig returns default configuration
func DefaultConfig() *Config {
return &Config{
MaxMessageSize: MaxMessageSize,
MaxHeaderSize: MaxHeaderSize,
MaxQueueLength: MaxQueueLength,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
EnableCompression: false,
BufferPoolSize: 1000,
}
}
// MessageType indicates the type of message being sent
type MessageType uint8
const (
MessageTypeStandard MessageType = iota
MessageTypeFragment
MessageTypeHeartbeat
MessageTypeAck
MessageTypeError
)
// MessageFlag represents various flags that can be set on messages
type MessageFlag uint16
const (
FlagNone MessageFlag = 0
FlagFragmented MessageFlag = 1 << iota
FlagCompressed
FlagEncrypted
FlagHighPriority
FlagRedelivered
FlagNoAck
)
// Message represents a protocol message with validation
type Message struct {
Headers map[string]string `msgpack:"h" json:"headers"`
Queue string `msgpack:"q" json:"queue"`
Payload []byte `msgpack:"p" json:"payload"`
Command consts.CMD `msgpack:"c" json:"command"`
Version uint8 `msgpack:"v" json:"version"`
Timestamp int64 `msgpack:"t" json:"timestamp"`
ID string `msgpack:"i" json:"id,omitempty"`
Flags MessageFlag `msgpack:"f" json:"flags"`
Type MessageType `msgpack:"mt" json:"messageType"`
FragmentID uint32 `msgpack:"fid" json:"fragmentId,omitempty"`
Fragments uint16 `msgpack:"fs" json:"fragments,omitempty"`
Sequence uint16 `msgpack:"seq" json:"sequence,omitempty"`
}
// Codec handles message encoding/decoding with configuration
type Codec struct {
config *Config
mu sync.RWMutex
stats *Stats
}
// Stats tracks codec statistics
type Stats struct {
MessagesSent uint64
MessagesReceived uint64
BytesSent uint64
BytesReceived uint64
Errors uint64
mu sync.RWMutex
}
// NewCodec creates a new codec with configuration
func NewCodec(config *Config) *Codec {
if config == nil {
config = DefaultConfig()
}
return &Codec{
config: config,
stats: &Stats{},
}
}
// NewMessage creates a validated message
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) (*Message, error) {
if err := validateCommand(cmd); err != nil {
return nil, err
}
if err := validateQueue(queue); err != nil {
return nil, err
}
if headers == nil {
headers = make(map[string]string)
}
if err := validateHeaders(headers); err != nil {
return nil, err
}
return &Message{
Headers: headers,
Queue: queue,
Command: cmd,
Payload: payload,
}
Version: ProtocolVersion,
Timestamp: time.Now().Unix(),
}, nil
}
// Validate performs message validation
func (m *Message) Validate(config *Config) error {
if m == nil {
return ErrInvalidMessage
}
if m.Version != ProtocolVersion {
return ErrProtocolMismatch
}
if err := validateCommand(m.Command); err != nil {
return err
}
if err := validateQueue(m.Queue); err != nil {
return err
}
if err := validateHeaders(m.Headers); err != nil {
return err
}
if len(m.Payload) > int(config.MaxMessageSize) {
return ErrMessageTooLarge
}
return nil
}
// Serialize converts message to bytes with validation
func (m *Message) Serialize() ([]byte, error) {
if m == nil {
return nil, ErrInvalidMessage
}
data, err := Marshal(m)
if err != nil {
return nil, err
return nil, fmt.Errorf("serialization failed: %w", err)
}
return data, nil
}
// Deserialize converts bytes to message with validation
func Deserialize(data []byte) (*Message, error) {
if len(data) == 0 {
return nil, ErrInvalidMessage
}
var msg Message
if err := Unmarshal(data, &msg); err != nil {
return nil, err
return nil, fmt.Errorf("deserialization failed: %w", err)
}
return &msg, nil
}
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
// SendMessage sends a message with proper error handling and timeouts
func (c *Codec) SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
// Check context cancellation before proceeding
if err := ctx.Err(); err != nil {
c.incrementErrors()
return fmt.Errorf("context ended before send: %w", err)
}
if msg == nil {
return ErrInvalidMessage
}
// Validate message
if err := msg.Validate(c.config); err != nil {
c.incrementErrors()
return fmt.Errorf("message validation failed: %w", err)
}
// Check if this is a fragment message, if so handle it directly
if msg.Type == MessageTypeFragment {
return c.sendRawMessage(ctx, conn, msg)
}
// Handle fragmentation for large messages if needed
if len(msg.Payload) > int(FragmentationThreshold) && msg.Type != MessageTypeFragment {
fm := NewFragmentManager(c, c.config)
defer fm.Stop()
return fm.sendFragmentedMessage(ctx, conn, msg)
}
// Standard message send path
return c.sendRawMessage(ctx, conn, msg)
}
// sendRawMessage handles the actual sending of a message or fragment
func (c *Codec) sendRawMessage(ctx context.Context, conn net.Conn, msg *Message) error {
// Serialize message
data, err := msg.Serialize()
if err != nil {
return err
c.incrementErrors()
return fmt.Errorf("message serialization failed: %w", err)
}
// Check message size
if len(data) > int(c.config.MaxMessageSize) {
c.incrementErrors()
return ErrMessageTooLarge
}
// Prepare buffer
totalLength := 4 + len(data)
buffer := bpool.Get()
defer bpool.Put(buffer)
buffer.Reset()
if cap(buffer.B) < totalLength {
buffer.B = make([]byte, totalLength)
} else {
buffer.B = buffer.B[:totalLength]
}
// Write length prefix and data
binary.BigEndian.PutUint32(buffer.B[:4], uint32(len(data)))
copy(buffer.B[4:], data)
// Set timeout
deadline := time.Now().Add(c.config.WriteTimeout)
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
deadline = ctxDeadline
}
if err := conn.SetWriteDeadline(deadline); err != nil {
c.incrementErrors()
return fmt.Errorf("failed to set write deadline: %w", err)
}
defer conn.SetWriteDeadline(time.Time{})
// Write with buffering
writer := bufio.NewWriter(conn)
// Set write deadline if context has one
if deadline, ok := ctx.Deadline(); ok {
conn.SetWriteDeadline(deadline)
defer conn.SetWriteDeadline(time.Time{})
written, err := writer.Write(buffer.B[:totalLength])
if err != nil {
c.incrementErrors()
return fmt.Errorf("write failed after %d bytes: %w", written, err)
}
// Write full data
if _, err := writer.Write(buffer.B[:totalLength]); err != nil {
return err
if err := writer.Flush(); err != nil {
c.incrementErrors()
return fmt.Errorf("flush failed: %w", err)
}
return writer.Flush()
// Update statistics
c.stats.mu.Lock()
c.stats.MessagesSent++
c.stats.BytesSent += uint64(totalLength)
c.stats.mu.Unlock()
return nil
}
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
lengthBytes := make([]byte, 4)
// Set read deadline if context has one
if deadline, ok := ctx.Deadline(); ok {
conn.SetReadDeadline(deadline)
// ReadMessage reads a message with proper error handling and timeouts
func (c *Codec) ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
// Check context cancellation before proceeding
if err := ctx.Err(); err != nil {
c.incrementErrors()
return nil, fmt.Errorf("context ended before read: %w", err)
}
// Check context cancellation before proceeding
if err := ctx.Err(); err != nil {
c.incrementErrors()
return nil, fmt.Errorf("context ended before read: %w", err)
}
// Set timeout
deadline := time.Now().Add(c.config.ReadTimeout)
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
deadline = ctxDeadline
}
if err := conn.SetReadDeadline(deadline); err != nil {
c.incrementErrors()
return nil, fmt.Errorf("failed to set read deadline: %w", err)
}
defer conn.SetReadDeadline(time.Time{})
}
// Use io.ReadFull to ensure all header bytes are read
// Read length prefix
lengthBytes := make([]byte, 4)
if _, err := io.ReadFull(conn, lengthBytes); err != nil {
return nil, err
c.incrementErrors()
if errors.Is(err, io.EOF) {
return nil, ErrConnectionClosed
}
return nil, fmt.Errorf("failed to read message length: %w", err)
}
length := binary.BigEndian.Uint32(lengthBytes)
// Validate message size
if length > c.config.MaxMessageSize {
c.incrementErrors()
return nil, ErrMessageTooLarge
}
if length == 0 {
c.incrementErrors()
return nil, ErrInvalidMessage
}
// Read message data
buffer := bpool.Get()
defer bpool.Put(buffer)
buffer.Reset()
if cap(buffer.B) < int(length) {
buffer.B = make([]byte, length)
} else {
buffer.B = buffer.B[:length]
}
// Read the entire message payload
if _, err := io.ReadFull(conn, buffer.B[:length]); err != nil {
return nil, err
c.incrementErrors()
if errors.Is(err, io.EOF) {
return nil, ErrConnectionClosed
}
return Deserialize(buffer.B[:length])
return nil, fmt.Errorf("failed to read message data: %w", err)
}
// Deserialize message
msg, err := Deserialize(buffer.B[:length])
if err != nil {
c.incrementErrors()
return nil, fmt.Errorf("failed to deserialize message: %w", err)
}
// Validate message
if err := msg.Validate(c.config); err != nil {
c.incrementErrors()
return nil, fmt.Errorf("message validation failed: %w", err)
}
// Handle message fragments if needed
if msg.Type == MessageTypeFragment || (msg.Flags&FlagFragmented) != 0 {
fm := NewFragmentManager(c, c.config)
reassembled, isFragment, err := fm.processFragment(msg)
if err != nil {
c.incrementErrors()
return nil, fmt.Errorf("fragment processing failed: %w", err)
}
// If this is a fragment but reassembly isn't complete yet
if isFragment && reassembled == nil {
// Update statistics but return nil with no error to indicate
// the caller should continue reading messages
c.stats.mu.Lock()
c.stats.MessagesReceived++
c.stats.BytesReceived += uint64(4 + length)
c.stats.mu.Unlock()
// Read the next fragment
return c.ReadMessage(ctx, conn)
}
// Use the reassembled message if available
if reassembled != nil {
msg = reassembled
}
}
// Update statistics
c.stats.mu.Lock()
c.stats.MessagesReceived++
c.stats.BytesReceived += uint64(4 + length)
c.stats.mu.Unlock()
return msg, nil
}
// GetStats returns codec statistics
func (c *Codec) GetStats() Stats {
c.stats.mu.RLock()
defer c.stats.mu.RUnlock()
return *c.stats
}
// ResetStats resets codec statistics
func (c *Codec) ResetStats() {
c.stats.mu.Lock()
defer c.stats.mu.Unlock()
*c.stats = Stats{}
}
// Helper functions for validation
func validateCommand(cmd consts.CMD) error {
// Add validation based on your command constants
if cmd < 0 {
return ErrInvalidCommand
}
return nil
}
func validateQueue(queue string) error {
if len(queue) == 0 || len(queue) > MaxQueueLength {
return ErrInvalidQueue
}
return nil
}
func validateHeaders(headers map[string]string) error {
totalSize := 0
for k, v := range headers {
totalSize += len(k) + len(v)
if totalSize > MaxHeaderSize {
return ErrMessageTooLarge
}
}
return nil
}
func (c *Codec) incrementErrors() {
c.stats.mu.Lock()
c.stats.Errors++
c.stats.mu.Unlock()
}
// SendMessage Backward compatibility functions
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
codec := NewCodec(DefaultConfig())
return codec.SendMessage(ctx, conn, msg)
}
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
codec := NewCodec(DefaultConfig())
return codec.ReadMessage(ctx, conn)
}

282
codec/fragment.go Normal file
View File

@@ -0,0 +1,282 @@
package codec
import (
"context"
"encoding/binary"
"fmt"
"hash/crc32"
"net"
"sync"
"time"
"github.com/oarkflow/mq/consts"
)
// FragmentManager handles message fragmentation and reassembly
type FragmentManager struct {
codec *Codec
config *Config
fragmentStore map[string]*fragmentAssembly
mu sync.RWMutex
cleanupInterval time.Duration
cleanupTimer *time.Timer
cleanupChan chan struct{}
reassemblyTimeout time.Duration
}
// fragmentAssembly holds fragments for a specific message
type fragmentAssembly struct {
fragments map[uint16][]byte
totalSize int
createdAt time.Time
totalCount uint16
messageType MessageType
queue string
command consts.CMD
headers map[string]string
id string
}
// NewFragmentManager creates a new fragment manager
func NewFragmentManager(codec *Codec, config *Config) *FragmentManager {
fm := &FragmentManager{
codec: codec,
config: config,
fragmentStore: make(map[string]*fragmentAssembly),
cleanupInterval: time.Minute,
reassemblyTimeout: 5 * time.Minute,
}
fm.startCleanupTimer()
return fm
}
// startCleanupTimer starts a timer to clean up old fragments
func (fm *FragmentManager) startCleanupTimer() {
fm.cleanupChan = make(chan struct{})
fm.cleanupTimer = time.NewTimer(fm.cleanupInterval)
go func() {
for {
select {
case <-fm.cleanupTimer.C:
fm.cleanupExpiredFragments()
fm.cleanupTimer.Reset(fm.cleanupInterval)
case <-fm.cleanupChan:
return
}
}
}()
}
// Stop stops the fragment manager
func (fm *FragmentManager) Stop() {
if fm.cleanupTimer != nil {
fm.cleanupTimer.Stop()
}
if fm.cleanupChan != nil {
close(fm.cleanupChan)
}
}
// cleanupExpiredFragments removes expired fragment assemblies
func (fm *FragmentManager) cleanupExpiredFragments() {
fm.mu.Lock()
defer fm.mu.Unlock()
now := time.Now()
for id, assembly := range fm.fragmentStore {
if now.Sub(assembly.createdAt) > fm.reassemblyTimeout {
delete(fm.fragmentStore, id)
}
}
}
// fragmentMessage splits a large message into fragments
func (fm *FragmentManager) fragmentMessage(ctx context.Context, msg *Message) ([]*Message, error) {
if msg == nil {
return nil, ErrInvalidMessage
}
// Check if fragmentation is needed
if len(msg.Payload) <= int(FragmentationThreshold) {
return []*Message{msg}, nil
}
// Calculate how many fragments we need
fragmentCount := (len(msg.Payload) + int(FragmentSize) - 1) / int(FragmentSize)
if fragmentCount > int(MaxFragments) {
return nil, fmt.Errorf("%w: would require %d fragments, maximum is %d",
ErrMessageTooLarge, fragmentCount, MaxFragments)
}
// Generate a fragment ID using a hash of message content + timestamp
idBytes := []byte(msg.ID)
timestampBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timestampBytes, uint64(msg.Timestamp))
hashInput := append(idBytes, timestampBytes...)
hashInput = append(hashInput, msg.Payload[:min(1024, len(msg.Payload))]...)
fragmentID := crc32.ChecksumIEEE(hashInput)
// Create fragment messages
fragments := make([]*Message, fragmentCount)
for i := 0; i < fragmentCount; i++ {
// Calculate fragment payload boundaries
start := i * int(FragmentSize)
end := min((i+1)*int(FragmentSize), len(msg.Payload))
fragmentPayload := msg.Payload[start:end]
// Create fragment message
fragment := &Message{
Headers: copyHeaders(msg.Headers),
Queue: msg.Queue,
Command: msg.Command,
Version: msg.Version,
Timestamp: msg.Timestamp,
ID: msg.ID,
Type: MessageTypeFragment,
Flags: msg.Flags | FlagFragmented,
FragmentID: fragmentID,
Fragments: uint16(fragmentCount),
Sequence: uint16(i),
Payload: fragmentPayload,
}
fragments[i] = fragment
}
return fragments, nil
}
// sendFragmentedMessage sends a large message as multiple fragments
func (fm *FragmentManager) sendFragmentedMessage(ctx context.Context, conn net.Conn, msg *Message) error {
fragments, err := fm.fragmentMessage(ctx, msg)
if err != nil {
return err
}
// If no fragmentation was needed, send as normal
if len(fragments) == 1 && fragments[0].Type != MessageTypeFragment {
return fm.codec.SendMessage(ctx, conn, fragments[0])
}
// Send each fragment
for _, fragment := range fragments {
if err := fm.codec.SendMessage(ctx, conn, fragment); err != nil {
return fmt.Errorf("failed to send fragment %d/%d: %w",
fragment.Sequence+1, fragment.Fragments, err)
}
}
return nil
}
// processFragment processes a fragment and attempts reassembly
func (fm *FragmentManager) processFragment(msg *Message) (*Message, bool, error) {
if msg == nil || msg.Type != MessageTypeFragment || msg.FragmentID == 0 {
return msg, false, nil // Not a fragment, return as is
}
// Generate a unique key for this fragmented message
key := fmt.Sprintf("%s-%d-%s", msg.ID, msg.FragmentID, msg.Queue)
fm.mu.Lock()
defer fm.mu.Unlock()
// Check if we already have an assembly for this message
assembly, exists := fm.fragmentStore[key]
if !exists {
// Create a new assembly
assembly = &fragmentAssembly{
fragments: make(map[uint16][]byte),
createdAt: time.Now(),
totalCount: msg.Fragments,
messageType: MessageTypeStandard,
queue: msg.Queue,
command: msg.Command,
headers: copyHeaders(msg.Headers),
id: msg.ID,
}
fm.fragmentStore[key] = assembly
}
// Store this fragment
assembly.fragments[msg.Sequence] = msg.Payload
assembly.totalSize += len(msg.Payload)
// Check if we have all fragments
if len(assembly.fragments) == int(assembly.totalCount) {
// We have all fragments, reassemble the message
reassembled, err := fm.reassembleMessage(key, assembly)
if err != nil {
return nil, true, err
}
return reassembled, true, nil
}
// Still waiting for more fragments
return nil, true, nil
}
// reassembleMessage combines fragments into the original message
func (fm *FragmentManager) reassembleMessage(key string, assembly *fragmentAssembly) (*Message, error) {
// Remove the assembly from the store
delete(fm.fragmentStore, key)
// Check if we have all fragments
if len(assembly.fragments) != int(assembly.totalCount) {
return nil, ErrFragmentMissing
}
// Allocate space for the full payload
fullPayload := make([]byte, assembly.totalSize)
// Combine fragments in order
offset := 0
for i := uint16(0); i < assembly.totalCount; i++ {
fragment, exists := assembly.fragments[i]
if !exists {
return nil, fmt.Errorf("%w: missing fragment %d", ErrFragmentMissing, i)
}
copy(fullPayload[offset:], fragment)
offset += len(fragment)
}
// Create the reassembled message
msg := &Message{
Headers: assembly.headers,
Queue: assembly.queue,
Command: assembly.command,
Version: ProtocolVersion,
Timestamp: time.Now().Unix(),
ID: assembly.id,
Type: assembly.messageType,
Flags: FlagNone, // Clear fragmentation flag
Payload: fullPayload,
}
return msg, nil
}
// Helper function to make a copy of headers to avoid shared references
func copyHeaders(src map[string]string) map[string]string {
if src == nil {
return nil
}
dst := make(map[string]string, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
// min returns the smaller of two integers
func min(a, b int) int {
if a < b {
return a
}
return b
}

193
codec/heartbeat.go Normal file
View File

@@ -0,0 +1,193 @@
package codec
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/oarkflow/mq/consts"
)
// HeartbeatManager manages heartbeat messages for connection health monitoring
type HeartbeatManager struct {
codec *Codec
interval time.Duration
timeout time.Duration
conn net.Conn
stopChan chan struct{}
lastHeartbeat atomic.Int64
lastReceived atomic.Int64
onFailure func(error)
mu sync.RWMutex
isRunning bool
heartbeatsSent uint64
heartbeatsRecv uint64
failedHeartbeats uint64
}
// NewHeartbeatManager creates a new heartbeat manager
func NewHeartbeatManager(codec *Codec, conn net.Conn) *HeartbeatManager {
return &HeartbeatManager{
codec: codec,
interval: 30 * time.Second,
timeout: 90 * time.Second, // 3x interval
conn: conn,
stopChan: make(chan struct{}),
}
}
// SetInterval sets the heartbeat interval
func (hm *HeartbeatManager) SetInterval(interval time.Duration) {
hm.mu.Lock()
defer hm.mu.Unlock()
hm.interval = interval
}
// SetTimeout sets the heartbeat timeout
func (hm *HeartbeatManager) SetTimeout(timeout time.Duration) {
hm.mu.Lock()
defer hm.mu.Unlock()
hm.timeout = timeout
}
// SetOnFailure sets the callback function for heartbeat failures
func (hm *HeartbeatManager) SetOnFailure(fn func(error)) {
hm.mu.Lock()
defer hm.mu.Unlock()
hm.onFailure = fn
}
// Start starts the heartbeat monitoring
func (hm *HeartbeatManager) Start() {
hm.mu.Lock()
defer hm.mu.Unlock()
if hm.isRunning {
return
}
hm.isRunning = true
hm.lastHeartbeat.Store(time.Now().Unix())
hm.lastReceived.Store(time.Now().Unix())
go hm.sendHeartbeats()
go hm.monitorHeartbeats()
}
// Stop stops the heartbeat monitoring
func (hm *HeartbeatManager) Stop() {
hm.mu.Lock()
defer hm.mu.Unlock()
if !hm.isRunning {
return
}
hm.isRunning = false
close(hm.stopChan)
// Create a new stop channel for future use
hm.stopChan = make(chan struct{})
}
// IsRunning returns whether the heartbeat manager is running
func (hm *HeartbeatManager) IsRunning() bool {
hm.mu.RLock()
defer hm.mu.RUnlock()
return hm.isRunning
}
// GetStats returns heartbeat statistics
func (hm *HeartbeatManager) GetStats() map[string]uint64 {
return map[string]uint64{
"sent": atomic.LoadUint64(&hm.heartbeatsSent),
"received": atomic.LoadUint64(&hm.heartbeatsRecv),
"failed": atomic.LoadUint64(&hm.failedHeartbeats),
}
}
// RecordHeartbeat records a received heartbeat
func (hm *HeartbeatManager) RecordHeartbeat() {
hm.lastReceived.Store(time.Now().Unix())
atomic.AddUint64(&hm.heartbeatsRecv, 1)
}
// sendHeartbeats sends periodic heartbeat messages
func (hm *HeartbeatManager) sendHeartbeats() {
ticker := time.NewTicker(hm.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := hm.sendHeartbeat(); err != nil {
atomic.AddUint64(&hm.failedHeartbeats, 1)
hm.mu.RLock()
onFailure := hm.onFailure
hm.mu.RUnlock()
if onFailure != nil {
onFailure(err)
}
}
case <-hm.stopChan:
return
}
}
}
// monitorHeartbeats monitors the heartbeat health
func (hm *HeartbeatManager) monitorHeartbeats() {
ticker := time.NewTicker(hm.interval / 2)
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now().Unix()
lastReceived := hm.lastReceived.Load()
// Check if we've exceeded the timeout
if now-lastReceived > int64(hm.timeout.Seconds()) {
err := fmt.Errorf("heartbeat timeout: last received %v seconds ago",
now-lastReceived)
atomic.AddUint64(&hm.failedHeartbeats, 1)
hm.mu.RLock()
onFailure := hm.onFailure
hm.mu.RUnlock()
if onFailure != nil {
onFailure(err)
}
}
case <-hm.stopChan:
return
}
}
}
// sendHeartbeat sends a single heartbeat message
func (hm *HeartbeatManager) sendHeartbeat() error {
msg := &Message{
Type: MessageTypeHeartbeat,
Command: consts.CMD(0), // Use appropriate heartbeat command from your consts
Version: ProtocolVersion,
Timestamp: time.Now().Unix(),
Headers: map[string]string{"type": "heartbeat"},
Payload: []byte{},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := hm.codec.SendMessage(ctx, hm.conn, msg)
if err == nil {
hm.lastHeartbeat.Store(time.Now().Unix())
atomic.AddUint64(&hm.heartbeatsSent, 1)
}
return err
}

View File

@@ -1,39 +1,418 @@
package codec
import (
"bytes"
"compress/gzip"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/oarkflow/json"
"golang.org/x/crypto/chacha20poly1305"
)
type MarshallerFunc func(v any) ([]byte, error)
// Error definitions for serialization
var (
ErrSerializationFailed = errors.New("serialization failed")
ErrDeserializationFailed = errors.New("deserialization failed")
ErrCompressionFailed = errors.New("compression failed")
ErrDecompressionFailed = errors.New("decompression failed")
ErrEncryptionFailed = errors.New("encryption failed")
ErrDecryptionFailed = errors.New("decryption failed")
ErrInvalidKey = errors.New("invalid encryption key")
)
type UnmarshallerFunc func(data []byte, v any) error
// ContentType represents the content type of serialized data
type ContentType string
const (
ContentTypeJSON ContentType = "application/json"
ContentTypeMsgPack ContentType = "application/msgpack"
ContentTypeCBOR ContentType = "application/cbor"
)
// Marshaller interface for pluggable serialization
type Marshaller interface {
Marshal(v any) ([]byte, error)
ContentType() ContentType
}
// Unmarshaller interface for pluggable deserialization
type Unmarshaller interface {
Unmarshal(data []byte, v any) error
ContentType() ContentType
}
// MarshallerFunc adapter
type MarshallerFunc func(v any) ([]byte, error)
func (f MarshallerFunc) Marshal(v any) ([]byte, error) {
return f(v)
}
func (f MarshallerFunc) ContentType() ContentType {
return ContentTypeJSON
}
// UnmarshallerFunc adapter
type UnmarshallerFunc func(data []byte, v any) error
func (f UnmarshallerFunc) Unmarshal(data []byte, v any) error {
return f(data, v)
}
var defaultMarshaller MarshallerFunc = json.Marshal
var defaultUnmarshaller UnmarshallerFunc = func(data []byte, v any) error {
return json.Unmarshal(data, v)
func (f UnmarshallerFunc) ContentType() ContentType {
return ContentTypeJSON
}
// SerializationConfig holds serialization configuration
type SerializationConfig struct {
EnableCompression bool
CompressionLevel int
MaxCompressionRatio float64
EnableEncryption bool
EncryptionKey []byte
PreferredCipher string // "chacha20poly1305" or "aes-gcm"
}
// DefaultSerializationConfig returns default configuration
func DefaultSerializationConfig() *SerializationConfig {
return &SerializationConfig{
EnableCompression: false,
CompressionLevel: gzip.DefaultCompression,
MaxCompressionRatio: 0.8, // Only compress if we save at least 20%
EnableEncryption: false,
PreferredCipher: "chacha20poly1305",
}
}
// SerializationManager manages serialization with configuration
type SerializationManager struct {
marshaller Marshaller
unmarshaller Unmarshaller
config *SerializationConfig
mu sync.RWMutex
cachedCipher cipher.AEAD
cipherMu sync.Mutex
}
// NewSerializationManager creates a new serialization manager
func NewSerializationManager(config *SerializationConfig) *SerializationManager {
if config == nil {
config = DefaultSerializationConfig()
}
return &SerializationManager{
marshaller: MarshallerFunc(json.Marshal),
unmarshaller: UnmarshallerFunc(func(data []byte, v any) error {
return json.Unmarshal(data, v)
}),
config: config,
}
}
// SetMarshaller sets custom marshaller
func (sm *SerializationManager) SetMarshaller(marshaller Marshaller) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.marshaller = marshaller
}
// SetUnmarshaller sets custom unmarshaller
func (sm *SerializationManager) SetUnmarshaller(unmarshaller Unmarshaller) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.unmarshaller = unmarshaller
}
// SetEncryptionKey sets the encryption key
func (sm *SerializationManager) SetEncryptionKey(key []byte) error {
if sm.config.EnableEncryption && len(key) == 0 {
return ErrInvalidKey
}
sm.mu.Lock()
defer sm.mu.Unlock()
sm.config.EncryptionKey = key
// Clear the cached cipher so it will be recreated with the new key
sm.cipherMu.Lock()
defer sm.cipherMu.Unlock()
sm.cachedCipher = nil
return nil
}
// Marshal serializes data with optional compression and encryption
func (sm *SerializationManager) Marshal(v any) ([]byte, error) {
sm.mu.RLock()
marshaller := sm.marshaller
config := sm.config
sm.mu.RUnlock()
// Serialize the data
data, err := marshaller.Marshal(v)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrSerializationFailed, err)
}
// Apply compression if enabled and beneficial
if config.EnableCompression && len(data) > 256 { // Only compress larger payloads
compressed, err := sm.compress(data)
if err != nil {
// Continue with uncompressed data on error
// but don't return error to allow for graceful degradation
} else if float64(len(compressed))/float64(len(data)) <= config.MaxCompressionRatio {
// Only use compression if it provides significant benefit
data = compressed
}
}
// Apply encryption if enabled
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
encrypted, err := sm.encrypt(data)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
}
data = encrypted
}
return data, nil
}
// Unmarshal deserializes data with optional decompression and decryption
func (sm *SerializationManager) Unmarshal(data []byte, v any) error {
if len(data) == 0 {
return fmt.Errorf("%w: empty data", ErrDeserializationFailed)
}
sm.mu.RLock()
unmarshaller := sm.unmarshaller
config := sm.config
sm.mu.RUnlock()
// Apply decryption if enabled
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
decrypted, err := sm.decrypt(data)
if err != nil {
return fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
}
data = decrypted
}
// Try decompression if enabled
if config.EnableCompression && sm.isCompressed(data) {
decompressed, err := sm.decompress(data)
if err != nil {
return fmt.Errorf("%w: %v", ErrDecompressionFailed, err)
}
data = decompressed
}
// Deserialize the data
if err := unmarshaller.Unmarshal(data, v); err != nil {
return fmt.Errorf("%w: %v", ErrDeserializationFailed, err)
}
return nil
}
// compress compresses data using gzip
func (sm *SerializationManager) compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
// Add compression marker
buf.WriteByte(0x1f) // gzip magic number
writer, err := gzip.NewWriterLevel(&buf, sm.config.CompressionLevel)
if err != nil {
return nil, err
}
if _, err := writer.Write(data); err != nil {
writer.Close()
return nil, err
}
if err := writer.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decompress decompresses gzip data
func (sm *SerializationManager) decompress(data []byte) ([]byte, error) {
if len(data) < 1 {
return nil, fmt.Errorf("invalid compressed data")
}
// Remove compression marker
if data[0] != 0x1f {
return data, nil // Not compressed
}
reader, err := gzip.NewReader(bytes.NewReader(data[1:]))
if err != nil {
return nil, err
}
defer reader.Close()
var buf bytes.Buffer
if _, err := buf.ReadFrom(reader); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// isCompressed checks if data is compressed
func (sm *SerializationManager) isCompressed(data []byte) bool {
return len(data) > 0 && data[0] == 0x1f
}
// getCipher returns a cipher.AEAD instance for encryption/decryption
func (sm *SerializationManager) getCipher() (cipher.AEAD, error) {
sm.cipherMu.Lock()
defer sm.cipherMu.Unlock()
if sm.cachedCipher != nil {
return sm.cachedCipher, nil
}
var aead cipher.AEAD
var err error
sm.mu.RLock()
key := sm.config.EncryptionKey
preferredCipher := sm.config.PreferredCipher
sm.mu.RUnlock()
switch preferredCipher {
case "chacha20poly1305":
aead, err = chacha20poly1305.New(key)
case "aes-gcm":
block, e := aes.NewCipher(key)
if e != nil {
err = e
break
}
aead, err = cipher.NewGCM(block)
default:
// Default to ChaCha20-Poly1305
aead, err = chacha20poly1305.New(key)
}
if err != nil {
return nil, err
}
sm.cachedCipher = aead
return aead, nil
}
// encrypt encrypts data using authenticated encryption
func (sm *SerializationManager) encrypt(data []byte) ([]byte, error) {
aead, err := sm.getCipher()
if err != nil {
return nil, err
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
// Add timestamp to associated data for replay protection
timestampBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timestampBytes, uint64(time.Now().Unix()))
// Encrypt and authenticate
ciphertext := aead.Seal(nil, nonce, data, timestampBytes)
// Prepend nonce and timestamp to ciphertext
result := make([]byte, len(nonce)+len(timestampBytes)+len(ciphertext))
copy(result, nonce)
copy(result[len(nonce):], timestampBytes)
copy(result[len(nonce)+len(timestampBytes):], ciphertext)
return result, nil
}
// decrypt decrypts data using authenticated decryption
func (sm *SerializationManager) decrypt(data []byte) ([]byte, error) {
aead, err := sm.getCipher()
if err != nil {
return nil, err
}
// Extract nonce
nonceSize := aead.NonceSize()
if len(data) < nonceSize+8 { // 8 bytes for timestamp
return nil, fmt.Errorf("ciphertext too short")
}
nonce := data[:nonceSize]
timestampBytes := data[nonceSize : nonceSize+8]
ciphertext := data[nonceSize+8:]
// Decrypt and verify
plaintext, err := aead.Open(nil, nonce, ciphertext, timestampBytes)
if err != nil {
return nil, err
}
return plaintext, nil
}
// Global serialization manager instance
var globalSerializationManager = NewSerializationManager(nil)
var globalSerializationManagerMu sync.RWMutex
// Global functions for backward compatibility
func SetMarshaller(marshaller MarshallerFunc) {
defaultMarshaller = marshaller
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.SetMarshaller(marshaller)
}
func SetUnmarshaller(unmarshaller UnmarshallerFunc) {
defaultUnmarshaller = unmarshaller
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.SetUnmarshaller(unmarshaller)
}
func EnableCompression(enable bool) {
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.config.EnableCompression = enable
}
func EnableEncryption(enable bool, key []byte) error {
globalSerializationManagerMu.Lock()
defer globalSerializationManagerMu.Unlock()
globalSerializationManager.config.EnableEncryption = enable
if enable {
return globalSerializationManager.SetEncryptionKey(key)
}
return nil
}
func Marshal(v any) ([]byte, error) {
return defaultMarshaller(v)
globalSerializationManagerMu.RLock()
defer globalSerializationManagerMu.RUnlock()
return globalSerializationManager.Marshal(v)
}
func Unmarshal(data []byte, v any) error {
return defaultUnmarshaller(data, v)
globalSerializationManagerMu.RLock()
defer globalSerializationManagerMu.RUnlock()
return globalSerializationManager.Unmarshal(data, v)
}

210
codec/testing.go Normal file
View File

@@ -0,0 +1,210 @@
package codec
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/oarkflow/mq/consts"
)
// MockConn implements net.Conn for testing
type MockConn struct {
ReadBuffer *bytes.Buffer
WriteBuffer *bytes.Buffer
ReadDelay time.Duration
WriteDelay time.Duration
IsClosed bool
ReadErr error
WriteErr error
mu sync.Mutex
}
// NewMockConn creates a new mock connection
func NewMockConn() *MockConn {
return &MockConn{
ReadBuffer: bytes.NewBuffer(nil),
WriteBuffer: bytes.NewBuffer(nil),
}
}
// Read implements the net.Conn Read method
func (m *MockConn) Read(b []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed {
return 0, errors.New("connection closed")
}
if m.ReadErr != nil {
return 0, m.ReadErr
}
if m.ReadDelay > 0 {
time.Sleep(m.ReadDelay)
}
return m.ReadBuffer.Read(b)
}
// Write implements the net.Conn Write method
func (m *MockConn) Write(b []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed {
return 0, errors.New("connection closed")
}
if m.WriteErr != nil {
return 0, m.WriteErr
}
if m.WriteDelay > 0 {
time.Sleep(m.WriteDelay)
}
return m.WriteBuffer.Write(b)
}
// Close implements the net.Conn Close method
func (m *MockConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.IsClosed = true
return nil
}
// LocalAddr implements the net.Conn LocalAddr method
func (m *MockConn) LocalAddr() net.Addr {
return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
}
// RemoteAddr implements the net.Conn RemoteAddr method
func (m *MockConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
}
// SetDeadline implements the net.Conn SetDeadline method
func (m *MockConn) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline implements the net.Conn SetReadDeadline method
func (m *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline implements the net.Conn SetWriteDeadline method
func (m *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
// CodecTestSuite provides utilities for testing the codec
type CodecTestSuite struct {
Codec *Codec
Config *Config
}
// NewCodecTestSuite creates a new codec test suite
func NewCodecTestSuite() *CodecTestSuite {
config := DefaultConfig()
// Set smaller timeouts for testing
config.ReadTimeout = 500 * time.Millisecond
config.WriteTimeout = 500 * time.Millisecond
return &CodecTestSuite{
Codec: NewCodec(config),
Config: config,
}
}
// SendReceiveTest tests sending and receiving a message
func (ts *CodecTestSuite) SendReceiveTest(msg *Message) error {
conn := NewMockConn()
// Send the message
ctx := context.Background()
if err := ts.Codec.SendMessage(ctx, conn, msg); err != nil {
return fmt.Errorf("failed to send message: %w", err)
}
// Move written data to read buffer to simulate network transport
conn.ReadBuffer.Write(conn.WriteBuffer.Bytes())
conn.WriteBuffer.Reset()
// Receive the message
received, err := ts.Codec.ReadMessage(ctx, conn)
if err != nil {
return fmt.Errorf("failed to receive message: %w", err)
}
// Validate the message
if received.Command != msg.Command {
return fmt.Errorf("command mismatch: got %v, want %v", received.Command, msg.Command)
}
if received.Queue != msg.Queue {
return fmt.Errorf("queue mismatch: got %v, want %v", received.Queue, msg.Queue)
}
if !bytes.Equal(received.Payload, msg.Payload) {
return fmt.Errorf("payload mismatch: got %d bytes, want %d bytes", len(received.Payload), len(msg.Payload))
}
return nil
}
// FragmentationTest tests the fragmentation and reassembly of large messages
func (ts *CodecTestSuite) FragmentationTest(payload []byte) error {
msg := &Message{
Command: consts.CMD(1), // Use appropriate command from your consts
Queue: "test_queue",
Headers: map[string]string{"test": "header"},
Payload: payload,
Version: ProtocolVersion,
Timestamp: time.Now().Unix(),
ID: "test-message-id",
}
conn := NewMockConn()
// Configure fragmentation
fm := NewFragmentManager(ts.Codec, ts.Config)
// Send the fragmented message
ctx := context.Background()
if err := fm.sendFragmentedMessage(ctx, conn, msg); err != nil {
return fmt.Errorf("failed to send fragmented message: %w", err)
}
// Move written data to read buffer to simulate network transport
conn.ReadBuffer.Write(conn.WriteBuffer.Bytes())
conn.WriteBuffer.Reset()
// Receive and reassemble the message
received, err := ts.Codec.ReadMessage(ctx, conn)
if err != nil {
return fmt.Errorf("failed to receive message: %w", err)
}
// Validate the reassembled message
if received.Command != msg.Command {
return fmt.Errorf("command mismatch: got %v, want %v", received.Command, msg.Command)
}
if received.Queue != msg.Queue {
return fmt.Errorf("queue mismatch: got %v, want %v", received.Queue, msg.Queue)
}
if !bytes.Equal(received.Payload, msg.Payload) {
return fmt.Errorf("payload mismatch: got %d bytes, want %d bytes", len(received.Payload), len(msg.Payload))
}
return nil
}

View File

@@ -1,99 +0,0 @@
{
"broker": {
"address": "localhost",
"port": 8080,
"max_connections": 1000,
"connection_timeout": "5s",
"read_timeout": "300s",
"write_timeout": "30s",
"idle_timeout": "600s",
"keep_alive": true,
"keep_alive_period": "60s",
"max_queue_depth": 10000,
"enable_dead_letter": true,
"dead_letter_max_retries": 3
},
"consumer": {
"enable_http_api": true,
"max_retries": 5,
"initial_delay": "2s",
"max_backoff": "30s",
"jitter_percent": 0.5,
"batch_size": 10,
"prefetch_count": 100,
"auto_ack": false,
"requeue_on_failure": true
},
"publisher": {
"enable_http_api": true,
"max_retries": 3,
"initial_delay": "1s",
"max_backoff": "10s",
"confirm_delivery": true,
"publish_timeout": "5s",
"connection_pool_size": 10
},
"pool": {
"queue_size": 1000,
"max_workers": 20,
"max_memory_load": 1073741824,
"idle_timeout": "300s",
"graceful_shutdown_timeout": "30s",
"task_timeout": "60s",
"enable_metrics": true,
"enable_diagnostics": true
},
"security": {
"enable_tls": false,
"tls_cert_path": "./certs/server.crt",
"tls_key_path": "./certs/server.key",
"tls_ca_path": "./certs/ca.crt",
"enable_auth": false,
"auth_provider": "jwt",
"jwt_secret": "your-secret-key",
"enable_encryption": false,
"encryption_key": "32-byte-encryption-key-here!!"
},
"monitoring": {
"metrics_port": 9090,
"health_check_port": 9091,
"enable_metrics": true,
"enable_health_checks": true,
"metrics_interval": "10s",
"health_check_interval": "30s",
"retention_period": "24h",
"enable_tracing": true,
"jaeger_endpoint": "http://localhost:14268/api/traces"
},
"persistence": {
"enable": true,
"provider": "postgres",
"connection_string": "postgres://user:password@localhost:5432/mq_db?sslmode=disable",
"max_connections": 50,
"connection_timeout": "30s",
"enable_migrations": true,
"backup_enabled": true,
"backup_interval": "6h"
},
"clustering": {
"enable": false,
"node_id": "node-1",
"cluster_name": "mq-cluster",
"peers": [ ],
"election_timeout": "5s",
"heartbeat_interval": "1s",
"enable_auto_discovery": false,
"discovery_port": 7946
},
"rate_limit": {
"broker_rate": 1000,
"broker_burst": 100,
"consumer_rate": 500,
"consumer_burst": 50,
"publisher_rate": 200,
"publisher_burst": 20,
"global_rate": 2000,
"global_burst": 200
},
"last_updated": "2025-07-29T00:00:00Z"
}

View File

@@ -152,7 +152,10 @@ func (c *Consumer) Metrics() Metrics {
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
msg, err := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
if err != nil {
return fmt.Errorf("error creating subscribe message: %v", err)
}
if err := c.send(ctx, c.conn, msg); err != nil {
return fmt.Errorf("error while trying to subscribe: %v", err)
}
@@ -207,7 +210,14 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C
func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
reply, err := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err != nil {
c.logger.Error("Failed to create MESSAGE_ACK",
logger.Field{Key: "queue", Value: msg.Queue},
logger.Field{Key: "task_id", Value: taskID},
logger.Field{Key: "error", Value: err.Error()})
return
}
// Send with timeout to avoid blocking
sendCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
@@ -375,7 +385,14 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
}
}
bt, _ := json.Marshal(result)
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
reply, err := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
if err != nil {
c.logger.Error("Failed to create MESSAGE_RESPONSE",
logger.Field{Key: "topic", Value: result.Topic},
logger.Field{Key: "task_id", Value: result.TaskID},
logger.Field{Key: "error", Value: err.Error()})
return
}
sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -395,7 +412,14 @@ func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, er
// Send deny message asynchronously to avoid blocking
go func() {
headers := HeadersWithConsumerID(ctx, c.id)
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
reply, err := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if err != nil {
c.logger.Error("Failed to create MESSAGE_DENY",
logger.Field{Key: "queue", Value: queue},
logger.Field{Key: "task_id", Value: taskID},
logger.Field{Key: "error", Value: err.Error()})
return
}
sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -820,7 +844,10 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(cmd, nil, c.queue, headers)
msg, err := codec.NewMessage(cmd, nil, c.queue, headers)
if err != nil {
return fmt.Errorf("error creating operation message: %v", err)
}
return c.send(ctx, c.conn, msg)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/oarkflow/json"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/utils"
"github.com/oarkflow/jet"
@@ -19,7 +20,7 @@ import (
func main() {
flow := dag.NewDAG("Email Notification System", "email-notification", func(taskID string, result mq.Result) {
fmt.Printf("Email notification workflow completed for task %s: %s\n", taskID, string(result.Payload))
fmt.Printf("Email notification workflow completed for task %s: %s\n", taskID, string(utils.RemoveRecursiveFromJSON(result.Payload, "html_content")))
})
// Add workflow nodes

View File

@@ -1,557 +0,0 @@
package main
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/logger"
)
// ExampleProcessor demonstrates a custom processor with debugging
type ExampleProcessor struct {
name string
tags []string
}
func NewExampleProcessor(name string) *ExampleProcessor {
return &ExampleProcessor{
name: name,
tags: []string{"example", "demo"},
}
}
func (p *ExampleProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
// Simulate processing time
time.Sleep(100 * time.Millisecond)
// Add some example processing logic
var data map[string]interface{}
if err := task.UnmarshalPayload(&data); err != nil {
return mq.Result{Error: err}
}
// Process the data
data["processed_by"] = p.name
data["processed_at"] = time.Now()
payload, _ := task.MarshalPayload(data)
return mq.Result{Payload: payload}
}
func (p *ExampleProcessor) SetConfig(payload dag.Payload) {}
func (p *ExampleProcessor) SetTags(tags ...string) { p.tags = append(p.tags, tags...) }
func (p *ExampleProcessor) GetTags() []string { return p.tags }
func (p *ExampleProcessor) Consume(ctx context.Context) error { return nil }
func (p *ExampleProcessor) Pause(ctx context.Context) error { return nil }
func (p *ExampleProcessor) Resume(ctx context.Context) error { return nil }
func (p *ExampleProcessor) Stop(ctx context.Context) error { return nil }
func (p *ExampleProcessor) Close() error { return nil }
func (p *ExampleProcessor) GetType() string { return "example" }
func (p *ExampleProcessor) GetKey() string { return p.name }
func (p *ExampleProcessor) SetKey(key string) { p.name = key }
// CustomActivityHook demonstrates custom activity processing
type CustomActivityHook struct {
logger logger.Logger
}
func (h *CustomActivityHook) OnActivity(entry dag.ActivityEntry) error {
// Custom processing of activity entries
if entry.Level == dag.ActivityLevelError {
h.logger.Error("Critical activity detected",
logger.Field{Key: "activity_id", Value: entry.ID},
logger.Field{Key: "dag_name", Value: entry.DAGName},
logger.Field{Key: "message", Value: entry.Message},
)
// Here you could send notifications, trigger alerts, etc.
}
return nil
}
// CustomAlertHandler demonstrates custom alert handling
type CustomAlertHandler struct {
logger logger.Logger
}
func (h *CustomAlertHandler) HandleAlert(alert dag.Alert) error {
h.logger.Warn("DAG Alert received",
logger.Field{Key: "type", Value: alert.Type},
logger.Field{Key: "severity", Value: alert.Severity},
logger.Field{Key: "message", Value: alert.Message},
)
// Here you could integrate with external alerting systems
// like Slack, PagerDuty, email, etc.
return nil
}
func main() {
// Initialize logger
log := logger.New(logger.Config{
Level: logger.LevelInfo,
Format: logger.FormatJSON,
})
// Create a comprehensive DAG with all enhanced features
server := mq.NewServer("demo", ":0", log)
// Create DAG with comprehensive configuration
dagInstance := dag.NewDAG("production-workflow", "workflow-key", func(ctx context.Context, result mq.Result) {
log.Info("Workflow completed",
logger.Field{Key: "result", Value: string(result.Payload)},
)
})
// Initialize all enhanced components
setupEnhancedDAG(dagInstance, log)
// Build the workflow
buildWorkflow(dagInstance, log)
// Start the server and DAG
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
if err := server.Start(ctx); err != nil {
log.Error("Server failed to start", logger.Field{Key: "error", Value: err.Error()})
}
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Start enhanced DAG features
startEnhancedFeatures(ctx, dagInstance, log)
// Set up HTTP API for monitoring and management
setupHTTPAPI(dagInstance, log)
// Start the HTTP server
go func() {
log.Info("Starting HTTP server on :8080")
if err := http.ListenAndServe(":8080", nil); err != nil {
log.Error("HTTP server failed", logger.Field{Key: "error", Value: err.Error()})
}
}()
// Demonstrate the enhanced features
demonstrateFeatures(ctx, dagInstance, log)
// Wait for shutdown signal
waitForShutdown(ctx, cancel, dagInstance, server, log)
}
func setupEnhancedDAG(dagInstance *dag.DAG, log logger.Logger) {
// Initialize activity logger with memory persistence
activityConfig := dag.DefaultActivityLoggerConfig()
activityConfig.BufferSize = 500
activityConfig.FlushInterval = 2 * time.Second
persistence := dag.NewMemoryActivityPersistence()
dagInstance.InitializeActivityLogger(activityConfig, persistence)
// Add custom activity hook
customHook := &CustomActivityHook{logger: log}
dagInstance.AddActivityHook(customHook)
// Initialize monitoring with comprehensive configuration
monitorConfig := dag.MonitoringConfig{
MetricsInterval: 5 * time.Second,
EnableHealthCheck: true,
BufferSize: 1000,
}
alertThresholds := &dag.AlertThresholds{
MaxFailureRate: 0.1, // 10%
MaxExecutionTime: 30 * time.Second,
MaxTasksInProgress: 100,
MinSuccessRate: 0.9, // 90%
MaxNodeFailures: 5,
HealthCheckInterval: 10 * time.Second,
}
dagInstance.InitializeMonitoring(monitorConfig, alertThresholds)
// Add custom alert handler
customAlertHandler := &CustomAlertHandler{logger: log}
dagInstance.AddAlertHandler(customAlertHandler)
// Initialize configuration management
dagInstance.InitializeConfigManager()
// Set up rate limiting
dagInstance.InitializeRateLimiter()
dagInstance.SetRateLimit("validate", 10.0, 5) // 10 req/sec, burst 5
dagInstance.SetRateLimit("process", 20.0, 10) // 20 req/sec, burst 10
dagInstance.SetRateLimit("finalize", 5.0, 2) // 5 req/sec, burst 2
// Initialize retry management
retryConfig := &dag.RetryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Second,
MaxDelay: 10 * time.Second,
BackoffFactor: 2.0,
Jitter: true,
RetryCondition: func(err error) bool {
// Custom retry condition - retry on specific errors
return err != nil && err.Error() != "permanent_failure"
},
}
dagInstance.InitializeRetryManager(retryConfig)
// Initialize transaction management
txConfig := dag.TransactionConfig{
DefaultTimeout: 5 * time.Minute,
CleanupInterval: 10 * time.Minute,
}
dagInstance.InitializeTransactionManager(txConfig)
// Initialize cleanup management
cleanupConfig := dag.CleanupConfig{
Interval: 5 * time.Minute,
TaskRetentionPeriod: 1 * time.Hour,
ResultRetentionPeriod: 2 * time.Hour,
MaxRetainedTasks: 1000,
}
dagInstance.InitializeCleanupManager(cleanupConfig)
// Initialize performance optimizer
dagInstance.InitializePerformanceOptimizer()
// Set up webhook manager for external notifications
httpClient := dag.NewSimpleHTTPClient(30 * time.Second)
webhookManager := dag.NewWebhookManager(httpClient, log)
// Add webhook for task completion events
webhookConfig := dag.WebhookConfig{
URL: "https://api.example.com/dag-events", // Replace with actual endpoint
Headers: map[string]string{"Authorization": "Bearer your-token"},
RetryCount: 3,
Events: []string{"task_completed", "task_failed", "dag_completed"},
}
webhookManager.AddWebhook("task_completed", webhookConfig)
dagInstance.SetWebhookManager(webhookManager)
log.Info("Enhanced DAG features initialized successfully")
}
func buildWorkflow(dagInstance *dag.DAG, log logger.Logger) {
// Create processors for each step
validator := NewExampleProcessor("validator")
processor := NewExampleProcessor("processor")
enricher := NewExampleProcessor("enricher")
finalizer := NewExampleProcessor("finalizer")
// Build the workflow with retry configurations
retryConfig := &dag.RetryConfig{
MaxRetries: 2,
InitialDelay: 500 * time.Millisecond,
MaxDelay: 5 * time.Second,
BackoffFactor: 2.0,
}
dagInstance.
AddNodeWithRetry(dag.Function, "Validate Input", "validate", validator, retryConfig, true).
AddNodeWithRetry(dag.Function, "Process Data", "process", processor, retryConfig).
AddNodeWithRetry(dag.Function, "Enrich Data", "enrich", enricher, retryConfig).
AddNodeWithRetry(dag.Function, "Finalize", "finalize", finalizer, retryConfig).
Connect("validate", "process").
Connect("process", "enrich").
Connect("enrich", "finalize")
// Add conditional connections
dagInstance.AddCondition("validate", "success", "process")
dagInstance.AddCondition("validate", "failure", "finalize") // Skip to finalize on validation failure
// Validate the DAG structure
if err := dagInstance.ValidateDAG(); err != nil {
log.Error("DAG validation failed", logger.Field{Key: "error", Value: err.Error()})
os.Exit(1)
}
log.Info("Workflow built and validated successfully")
}
func startEnhancedFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) {
// Start monitoring
dagInstance.StartMonitoring(ctx)
// Start cleanup manager
dagInstance.StartCleanup(ctx)
// Enable batch processing
dagInstance.SetBatchProcessingEnabled(true)
log.Info("Enhanced features started")
}
func setupHTTPAPI(dagInstance *dag.DAG, log logger.Logger) {
// Set up standard DAG handlers
dagInstance.Handlers(http.DefaultServeMux, "/dag")
// Set up enhanced API endpoints
enhancedAPI := dag.NewEnhancedAPIHandler(dagInstance)
enhancedAPI.RegisterRoutes(http.DefaultServeMux)
// Custom endpoints for demonstration
http.HandleFunc("/demo/activities", func(w http.ResponseWriter, r *http.Request) {
filter := dag.ActivityFilter{
Limit: 50,
}
activities, err := dagInstance.GetActivities(filter)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, activities); err != nil {
log.Error("Failed to write activities response", logger.Field{Key: "error", Value: err.Error()})
}
})
http.HandleFunc("/demo/stats", func(w http.ResponseWriter, r *http.Request) {
stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, stats); err != nil {
log.Error("Failed to write stats response", logger.Field{Key: "error", Value: err.Error()})
}
})
log.Info("HTTP API endpoints configured")
}
func demonstrateFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) {
log.Info("Demonstrating enhanced DAG features...")
// 1. Process a successful task
log.Info("Processing successful task...")
processTask(ctx, dagInstance, map[string]interface{}{
"id": "task-001",
"data": "valid input data",
"type": "success",
}, log)
// 2. Process a task that will fail
log.Info("Processing failing task...")
processTask(ctx, dagInstance, map[string]interface{}{
"id": "task-002",
"data": nil, // This will cause processing issues
"type": "failure",
}, log)
// 3. Process with transaction
log.Info("Processing with transaction...")
processWithTransaction(ctx, dagInstance, map[string]interface{}{
"id": "task-003",
"data": "transaction data",
"type": "transaction",
}, log)
// 4. Demonstrate rate limiting
log.Info("Demonstrating rate limiting...")
demonstrateRateLimiting(ctx, dagInstance, log)
// 5. Show monitoring metrics
time.Sleep(2 * time.Second) // Allow time for metrics to accumulate
showMetrics(dagInstance, log)
// 6. Show activity logs
showActivityLogs(dagInstance, log)
}
func processTask(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) {
// Add context information
ctx = context.WithValue(ctx, "user_id", "demo-user")
ctx = context.WithValue(ctx, "session_id", "demo-session")
ctx = context.WithValue(ctx, "trace_id", mq.NewID())
result := dagInstance.Process(ctx, payload)
if result.Error != nil {
log.Error("Task processing failed",
logger.Field{Key: "error", Value: result.Error.Error()},
logger.Field{Key: "payload", Value: payload},
)
} else {
log.Info("Task processed successfully",
logger.Field{Key: "result_size", Value: len(result.Payload)},
)
}
}
func processWithTransaction(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) {
taskID := fmt.Sprintf("tx-%s", mq.NewID())
// Begin transaction
tx := dagInstance.BeginTransaction(taskID)
if tx == nil {
log.Error("Failed to begin transaction")
return
}
// Add transaction context
ctx = context.WithValue(ctx, "transaction_id", tx.ID)
ctx = context.WithValue(ctx, "task_id", taskID)
// Process the task
result := dagInstance.Process(ctx, payload)
// Commit or rollback based on result
if result.Error != nil {
if err := dagInstance.RollbackTransaction(tx.ID); err != nil {
log.Error("Failed to rollback transaction",
logger.Field{Key: "tx_id", Value: tx.ID},
logger.Field{Key: "error", Value: err.Error()},
)
} else {
log.Info("Transaction rolled back",
logger.Field{Key: "tx_id", Value: tx.ID},
)
}
} else {
if err := dagInstance.CommitTransaction(tx.ID); err != nil {
log.Error("Failed to commit transaction",
logger.Field{Key: "tx_id", Value: tx.ID},
logger.Field{Key: "error", Value: err.Error()},
)
} else {
log.Info("Transaction committed",
logger.Field{Key: "tx_id", Value: tx.ID},
)
}
}
}
func demonstrateRateLimiting(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) {
// Try to exceed rate limits
for i := 0; i < 15; i++ {
allowed := dagInstance.CheckRateLimit("validate")
log.Info("Rate limit check",
logger.Field{Key: "attempt", Value: i + 1},
logger.Field{Key: "allowed", Value: allowed},
)
if allowed {
processTask(ctx, dagInstance, map[string]interface{}{
"id": fmt.Sprintf("rate-test-%d", i),
"data": "rate limiting test",
}, log)
}
time.Sleep(100 * time.Millisecond)
}
}
func showMetrics(dagInstance *dag.DAG, log logger.Logger) {
metrics := dagInstance.GetMonitoringMetrics()
if metrics != nil {
log.Info("Current DAG Metrics",
logger.Field{Key: "total_tasks", Value: metrics.TasksTotal},
logger.Field{Key: "completed_tasks", Value: metrics.TasksCompleted},
logger.Field{Key: "failed_tasks", Value: metrics.TasksFailed},
logger.Field{Key: "tasks_in_progress", Value: metrics.TasksInProgress},
logger.Field{Key: "avg_execution_time", Value: metrics.AverageExecutionTime.String()},
)
// Show node-specific metrics
for nodeID := range map[string]bool{"validate": true, "process": true, "enrich": true, "finalize": true} {
if nodeStats := dagInstance.GetNodeStats(nodeID); nodeStats != nil {
log.Info("Node Metrics",
logger.Field{Key: "node_id", Value: nodeID},
logger.Field{Key: "executions", Value: nodeStats.TotalExecutions},
logger.Field{Key: "failures", Value: nodeStats.FailureCount},
logger.Field{Key: "avg_duration", Value: nodeStats.AverageExecutionTime.String()},
)
}
}
} else {
log.Warn("Monitoring metrics not available")
}
}
func showActivityLogs(dagInstance *dag.DAG, log logger.Logger) {
// Get recent activities
filter := dag.ActivityFilter{
Limit: 10,
SortBy: "timestamp",
SortOrder: "desc",
}
activities, err := dagInstance.GetActivities(filter)
if err != nil {
log.Error("Failed to get activities", logger.Field{Key: "error", Value: err.Error()})
return
}
log.Info("Recent Activities", logger.Field{Key: "count", Value: len(activities)})
for _, activity := range activities {
log.Info("Activity",
logger.Field{Key: "id", Value: activity.ID},
logger.Field{Key: "type", Value: string(activity.Type)},
logger.Field{Key: "level", Value: string(activity.Level)},
logger.Field{Key: "message", Value: activity.Message},
logger.Field{Key: "task_id", Value: activity.TaskID},
logger.Field{Key: "node_id", Value: activity.NodeID},
)
}
// Get activity statistics
stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{})
if err != nil {
log.Error("Failed to get activity stats", logger.Field{Key: "error", Value: err.Error()})
return
}
log.Info("Activity Statistics",
logger.Field{Key: "total_activities", Value: stats.TotalActivities},
logger.Field{Key: "success_rate", Value: fmt.Sprintf("%.2f%%", stats.SuccessRate*100)},
logger.Field{Key: "failure_rate", Value: fmt.Sprintf("%.2f%%", stats.FailureRate*100)},
logger.Field{Key: "avg_duration", Value: stats.AverageDuration.String()},
)
}
func waitForShutdown(ctx context.Context, cancel context.CancelFunc, dagInstance *dag.DAG, server *mq.Server, log logger.Logger) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
log.Info("DAG system is running. Available endpoints:",
logger.Field{Key: "workflow", Value: "http://localhost:8080/dag/"},
logger.Field{Key: "process", Value: "http://localhost:8080/dag/process"},
logger.Field{Key: "metrics", Value: "http://localhost:8080/api/dag/metrics"},
logger.Field{Key: "health", Value: "http://localhost:8080/api/dag/health"},
logger.Field{Key: "activities", Value: "http://localhost:8080/demo/activities"},
logger.Field{Key: "stats", Value: "http://localhost:8080/demo/stats"},
)
<-sigChan
log.Info("Shutdown signal received, cleaning up...")
// Graceful shutdown
cancel()
// Stop enhanced features
dagInstance.StopEnhanced(ctx)
// Stop server
if err := server.Stop(ctx); err != nil {
log.Error("Error stopping server", logger.Field{Key: "error", Value: err.Error()})
}
log.Info("Shutdown complete")
}

View File

@@ -13,7 +13,7 @@ func main() {
Payload: payload,
}
publisher := mq.NewPublisher("publish-1", mq.WithBrokerURL(":8081"))
for i := 0; i < 10000000; i++ {
for i := 0; i < 2; i++ {
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
err := publisher.Publish(context.Background(), task, "queue1")
if err != nil {

View File

@@ -10,6 +10,7 @@ import (
"github.com/oarkflow/json"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/utils"
"github.com/oarkflow/jet"
@@ -19,7 +20,7 @@ import (
func main() {
flow := dag.NewDAG("SMS Sender", "sms-sender", func(taskID string, result mq.Result) {
fmt.Printf("SMS workflow completed for task %s: %s\n", taskID, string(result.Payload))
fmt.Printf("SMS workflow completed for task %s: %s\n", taskID, string(utils.RemoveRecursiveFromJSON(result.Payload, "html_content")))
})
// Add SMS workflow nodes

6
go.mod
View File

@@ -15,6 +15,8 @@ require (
github.com/oarkflow/log v1.0.79
github.com/oarkflow/xid v1.2.5
github.com/prometheus/client_golang v1.21.1
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
golang.org/x/time v0.11.0
)
@@ -23,13 +25,16 @@ require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/goccy/go-reflect v1.2.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.63.0 // indirect
github.com/prometheus/procfs v0.16.0 // indirect
@@ -38,4 +43,5 @@ require (
github.com/valyala/fasthttp v1.59.0 // indirect
golang.org/x/sys v0.31.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

12
go.sum
View File

@@ -4,6 +4,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/goccy/go-reflect v1.2.0 h1:O0T8rZCuNmGXewnATuKYnkL0xm6o8UNOJZd/gOkb9ms=
@@ -18,6 +19,10 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
@@ -59,6 +64,8 @@ github.com/prometheus/procfs v0.16.0/go.mod h1:8veyXUu3nGP7oaCxhX6yeaM5u4stL2FeM
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
@@ -67,6 +74,8 @@ github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDp
github.com/valyala/fasthttp v1.59.0/go.mod h1:GTxNb9Bc6r2a9D0TWNSPwDz78UxnTGBViY3xZNEqyYU=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -76,5 +85,8 @@ golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

31
mq.go
View File

@@ -688,7 +688,10 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
if err != nil {
return err
}
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
msg, err := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
if err != nil {
return fmt.Errorf("failed to create PUBLISH message: %w", err)
}
b.broadcastToConsumers(msg)
return nil
}
@@ -698,7 +701,11 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err != nil {
log.Printf("Error creating PUBLISH_ACK message: %v\n", err)
return
}
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
@@ -713,7 +720,11 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
ack, err := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err != nil {
log.Printf("Error creating SUBSCRIBE_ACK message: %v\n", err)
return
}
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
}
@@ -894,8 +905,12 @@ func (b *Broker) handleConsumer(
fn := func(queue *Queue) {
con, ok := queue.consumers.Get(consumerID)
if ok {
ack := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
err := b.send(ctx, con.conn, ack)
ack, err := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
if err != nil {
log.Printf("Error creating message for consumer %s: %v", consumerID, err)
return
}
err = b.send(ctx, con.conn, ack)
if err == nil {
con.state = state
}
@@ -921,7 +936,11 @@ func (b *Broker) UpdateConsumer(ctx context.Context, consumerID string, config D
fn := func(queue *Queue) error {
con, ok := queue.consumers.Get(consumerID)
if ok {
ack := codec.NewMessage(consts.CONSUMER_UPDATE, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
ack, err := codec.NewMessage(consts.CONSUMER_UPDATE, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
if err != nil {
log.Printf("Error creating message for consumer %s: %v", consumerID, err)
return err
}
return b.send(ctx, con.conn, ack)
}
return nil

View File

@@ -478,6 +478,7 @@ func (wp *Pool) processNextBatch() {
if len(tasks) > 0 {
for _, task := range tasks {
if task != nil && !wp.gracefulShutdown {
wp.taskCompletionNotifier.Add(1)
wp.handleTask(task)
}
}
@@ -487,6 +488,8 @@ func (wp *Pool) processNextBatch() {
func (wp *Pool) handleTask(task *QueueTask) {
if task == nil || task.payload == nil {
wp.logger.Warn().Msg("Received nil task or payload")
// Only call Done if Add was called (which is now only for actual tasks)
wp.taskCompletionNotifier.Done()
return
}
@@ -803,9 +806,6 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
wp.taskAvailableCond.Signal()
wp.taskAvailableCond.L.Unlock()
// Track pending task
wp.taskCompletionNotifier.Add(1)
// Update metrics
atomic.AddInt64(&wp.metrics.TotalScheduled, 1)

View File

@@ -111,7 +111,10 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
if err != nil {
return err
}
msg := codec.NewMessage(command, payload, queue, headers)
msg, err := codec.NewMessage(command, payload, queue, headers)
if err != nil {
return err
}
if err := codec.SendMessage(ctx, conn, msg); err != nil {
return err
}

45
utils/json.go Normal file
View File

@@ -0,0 +1,45 @@
package utils
import (
"github.com/oarkflow/json"
"github.com/oarkflow/json/jsonparser"
)
func RemoveFromJSONBye(jsonStr json.RawMessage, key ...string) json.RawMessage {
return jsonparser.Delete(jsonStr, key...)
}
func RemoveRecursiveFromJSON(jsonStr json.RawMessage, key ...string) json.RawMessage {
var data any
if err := json.Unmarshal(jsonStr, &data); err != nil {
return jsonStr
}
for _, k := range key {
data = removeKeyRecursive(data, k)
}
result, err := json.Marshal(data)
if err != nil {
return jsonStr
}
return result
}
func removeKeyRecursive(data any, key string) any {
switch v := data.(type) {
case map[string]any:
delete(v, key)
for k, val := range v {
v[k] = removeKeyRecursive(val, key)
}
return v
case []any:
for i, item := range v {
v[i] = removeKeyRecursive(item, key)
}
return v
default:
return v
}
}