diff --git a/apperror/errors.go b/apperror/errors.go deleted file mode 100644 index 8c4485f..0000000 --- a/apperror/errors.go +++ /dev/null @@ -1,343 +0,0 @@ -// apperror/apperror.go -package apperror - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "path/filepath" - "runtime" - "strings" - "sync" -) - -// APP_ENV values -const ( - EnvDevelopment = "development" - EnvStaging = "staging" - EnvProduction = "production" -) - -// AppError defines a structured application error -type AppError struct { - Code string `json:"code"` // 9-digit code: XXX|AA|DD|YY - Message string `json:"message"` // human-readable message - StatusCode int `json:"-"` // HTTP status, not serialized - Err error `json:"-"` // wrapped error, not serialized - Metadata map[string]any `json:"metadata,omitempty"` // optional extra info - StackTrace []string `json:"stackTrace,omitempty"` -} - -// Error implements error interface -func (e *AppError) Error() string { - if e.Err != nil { - return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Err) - } - return fmt.Sprintf("[%s] %s", e.Code, e.Message) -} - -// Unwrap enables errors.Is / errors.As -func (e *AppError) Unwrap() error { - return e.Err -} - -// WithMetadata returns a shallow copy with added metadata key/value -func (e *AppError) WithMetadata(key string, val any) *AppError { - newMD := make(map[string]any, len(e.Metadata)+1) - for k, v := range e.Metadata { - newMD[k] = v - } - newMD[key] = val - - return &AppError{ - Code: e.Code, - Message: e.Message, - StatusCode: e.StatusCode, - Err: e.Err, - Metadata: newMD, - StackTrace: e.StackTrace, - } -} - -// GetStackTraceArray returns the error stack trace as an array of strings -func (e *AppError) GetStackTraceArray() []string { - return e.StackTrace -} - -// GetStackTraceString returns the error stack trace as a single string -func (e *AppError) GetStackTraceString() string { - return strings.Join(e.StackTrace, "\n") -} - -// captureStackTrace returns a slice of strings representing the stack trace. -func captureStackTrace() []string { - const depth = 32 - var pcs [depth]uintptr - n := runtime.Callers(3, pcs[:]) - frames := runtime.CallersFrames(pcs[:n]) - isDebug := os.Getenv("APP_DEBUG") == "true" - var stack []string - for { - frame, more := frames.Next() - var file string - if !isDebug { - file = "/" + filepath.Base(frame.File) - } else { - file = frame.File - } - if strings.HasSuffix(file, ".go") { - file = strings.TrimSuffix(file, ".go") + ".sec" - } - stack = append(stack, fmt.Sprintf("%s:%d %s", file, frame.Line, frame.Function)) - if !more { - break - } - } - return stack -} - -// buildCode constructs a 9-digit code: XXX|AA|DD|YY -func buildCode(httpCode, appCode, domainCode, errCode int) string { - return fmt.Sprintf("%03d%02d%02d%02d", httpCode, appCode, domainCode, errCode) -} - -// New creates a fresh AppError -func New(httpCode, appCode, domainCode, errCode int, msg string) *AppError { - return &AppError{ - Code: buildCode(httpCode, appCode, domainCode, errCode), - Message: msg, - StatusCode: httpCode, - // Prototype: no StackTrace captured at registration time. - } -} - -// Modify Wrap to always capture a fresh stack trace. -func Wrap(err error, httpCode, appCode, domainCode, errCode int, msg string) *AppError { - return &AppError{ - Code: buildCode(httpCode, appCode, domainCode, errCode), - Message: msg, - StatusCode: httpCode, - Err: err, - StackTrace: captureStackTrace(), - } -} - -// New helper: Instance attaches the runtime stack trace to a prototype error. 
-func Instance(e *AppError) *AppError {
- // Create a shallow copy and attach the current stack trace.
- copyE := *e
- copyE.StackTrace = captureStackTrace()
- return &copyE
-}
-
-// Modify toAppError to instance a prototype if it lacks a stack trace.
-func toAppError(err error) *AppError {
- if err == nil {
- return nil
- }
- var ae *AppError
- if errors.As(err, &ae) {
- if len(ae.StackTrace) == 0 { // Prototype without context.
- return Instance(ae)
- }
- return ae
- }
- // fallback to internal error 500|00|00|00 with fresh stack trace.
- return Wrap(err, http.StatusInternalServerError, 0, 0, 0, "Internal server error")
-}
-
-// onError, if set, is called before writing any JSON error
-var onError func(*AppError)
-
-func OnError(hook func(*AppError)) {
- onError = hook
-}
-
-// WriteJSONError writes an error as JSON, includes X-Request-ID, hides details in production
-func WriteJSONError(w http.ResponseWriter, r *http.Request, err error) {
- appErr := toAppError(err)
-
- // attach request ID
- if rid := r.Header.Get("X-Request-ID"); rid != "" {
- appErr = appErr.WithMetadata("request_id", rid)
- }
- // hook
- if onError != nil {
- onError(appErr)
- }
- // If no stack trace is present, capture current context stack trace.
- if os.Getenv("APP_ENV") != EnvProduction {
- appErr.StackTrace = captureStackTrace()
- }
-
- fmt.Println(appErr.StackTrace)
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(appErr.StatusCode)
-
- resp := map[string]any{
- "code": appErr.Code,
- "message": appErr.Message,
- }
- if len(appErr.Metadata) > 0 {
- resp["metadata"] = appErr.Metadata
- }
- if os.Getenv("APP_ENV") != EnvProduction {
- resp["stack"] = appErr.StackTrace
- }
- if appErr.Err != nil {
- resp["details"] = appErr.Err.Error()
- }
-
- _ = json.NewEncoder(w).Encode(resp)
-}
-
-type ErrorRegistry struct {
- registry map[string]*AppError
- mu sync.RWMutex
-}
-
-func (er *ErrorRegistry) Get(name string) (*AppError, bool) {
- er.mu.RLock()
- defer er.mu.RUnlock()
- e, ok := er.registry[name]
- return e, ok
-}
-
-func (er *ErrorRegistry) Set(name string, e *AppError) {
- er.mu.Lock()
- defer er.mu.Unlock()
- er.registry[name] = e
-}
-
-func (er *ErrorRegistry) Delete(name string) {
- er.mu.Lock()
- defer er.mu.Unlock()
- delete(er.registry, name)
-}
-
-func (er *ErrorRegistry) List() []*AppError {
- er.mu.RLock()
- defer er.mu.RUnlock()
- out := make([]*AppError, 0, len(er.registry))
- for _, e := range er.registry {
- // create a shallow copy and remove the StackTrace for listing
- copyE := *e
- copyE.StackTrace = nil
- out = append(out, &copyE)
- }
- return out
-}
-
-func (er *ErrorRegistry) GetByCode(code string) (*AppError, bool) {
- er.mu.RLock()
- defer er.mu.RUnlock()
- for _, e := range er.registry {
- if e.Code == code {
- return e, true
- }
- }
- return nil, false
-}
-
-var (
- registry *ErrorRegistry
-)
-
-// Register adds a named error; fails if name exists
-func Register(name string, e *AppError) error {
- if name == "" {
- return fmt.Errorf("error name cannot be empty")
- }
- registry.Set(name, e)
- return nil
-}
-
-// Update replaces an existing named error; fails if not found
-func Update(name string, e *AppError) error {
- if name == "" {
- return fmt.Errorf("error name cannot be empty")
- }
- registry.Set(name, e)
- return nil
-}
-
-// Unregister removes a named error
-func Unregister(name string) error {
- if name == "" {
- return fmt.Errorf("error name cannot be empty")
- }
- registry.Delete(name)
- return nil
-}
-
-// Get retrieves a named error
-func Get(name string) 
(*AppError, bool) { - return registry.Get(name) -} - -// GetByCode retrieves an error by its 9-digit code -func GetByCode(code string) (*AppError, bool) { - if code == "" { - return nil, false - } - return registry.GetByCode(code) -} - -// List returns all registered errors -func List() []*AppError { - return registry.List() -} - -// Is/As shortcuts updated to check all registered errors -func Is(err, target error) bool { - if errors.Is(err, target) { - return true - } - registry.mu.RLock() - defer registry.mu.RUnlock() - for _, e := range registry.registry { - if errors.Is(err, e) || errors.Is(e, target) { - return true - } - } - return false -} - -func As(err error, target any) bool { - if errors.As(err, target) { - return true - } - registry.mu.RLock() - defer registry.mu.RUnlock() - for _, e := range registry.registry { - if errors.As(err, target) || errors.As(e, target) { - return true - } - } - return false -} - -// HTTPMiddleware catches panics and converts to JSON 500 -func HTTPMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - if rec := recover(); rec != nil { - p := fmt.Errorf("panic: %v", rec) - WriteJSONError(w, r, Wrap(p, http.StatusInternalServerError, 0, 0, 0, "Internal server error")) - } - }() - next.ServeHTTP(w, r) - }) -} - -// preload some common errors (with 2-digit app/domain codes) -func init() { - registry = &ErrorRegistry{registry: make(map[string]*AppError)} - _ = Register("ErrNotFound", New(http.StatusNotFound, 1, 1, 1, "Resource not found")) // → "404010101" - _ = Register("ErrInvalidInput", New(http.StatusBadRequest, 1, 1, 2, "Invalid input provided")) // → "400010102" - _ = Register("ErrInternal", New(http.StatusInternalServerError, 1, 1, 0, "Internal server error")) // → "500010100" - _ = Register("ErrUnauthorized", New(http.StatusUnauthorized, 1, 1, 3, "Unauthorized")) // → "401010103" - _ = Register("ErrForbidden", New(http.StatusForbidden, 1, 1, 4, "Forbidden")) // → "403010104" -} diff --git a/codec/README.md b/codec/README.md new file mode 100644 index 0000000..60b4fa1 --- /dev/null +++ b/codec/README.md @@ -0,0 +1,181 @@ +# Message Queue Codec + +This package provides a robust, production-ready codec implementation for serializing, transmitting, and deserializing messages in a distributed messaging system. 
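+
+## Wire Format
+
+On the wire, every message is a simple length-prefixed frame: a 4-byte big-endian length followed by the serialized (and optionally compressed or encrypted) message body. The snippet below is a minimal sketch of that framing only; `writeFrame` and `readFrame` are hypothetical helpers for illustration, not part of this package's API. In real code use `SendMessage`/`ReadMessage`, which also handle validation, timeouts, and fragmentation.
+
+```go
+package main
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+)
+
+// writeFrame writes one frame: a 4-byte big-endian length prefix
+// followed by the already-serialized message bytes.
+func writeFrame(w io.Writer, payload []byte) error {
+	var length [4]byte
+	binary.BigEndian.PutUint32(length[:], uint32(len(payload)))
+	if _, err := w.Write(length[:]); err != nil {
+		return err
+	}
+	_, err := w.Write(payload)
+	return err
+}
+
+// readFrame reads one frame back, rejecting empty or oversized frames.
+func readFrame(r io.Reader, maxSize uint32) ([]byte, error) {
+	var length [4]byte
+	if _, err := io.ReadFull(r, length[:]); err != nil {
+		return nil, err
+	}
+	n := binary.BigEndian.Uint32(length[:])
+	if n == 0 || n > maxSize {
+		return nil, fmt.Errorf("invalid frame length %d", n)
+	}
+	payload := make([]byte, n)
+	if _, err := io.ReadFull(r, payload); err != nil {
+		return nil, err
+	}
+	return payload, nil
+}
+
+func main() {
+	var buf bytes.Buffer
+	_ = writeFrame(&buf, []byte(`{"queue":"my-queue"}`))
+	frame, _ := readFrame(&buf, 64*1024*1024) // 64MB matches the package default limit
+	fmt.Printf("recovered %d bytes: %s\n", len(frame), frame)
+}
+```
+
+Length-prefixed framing lets the reader allocate exactly one buffer per message and reject oversized frames before reading them.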
+
+## Features
+
+- **Message Validation**: Comprehensive validation of message format and content
+- **Efficient Serialization**: Pluggable serialization with JSON as default
+- **Compression**: Optional payload compression for large messages
+- **Encryption**: Optional message encryption for sensitive data
+- **Large Message Support**: Automatic fragmentation and reassembly of large messages
+- **Connection Health**: Heartbeat mechanism for connection monitoring
+- **Performance Optimized**: Buffer pooling, efficient memory usage
+- **Robust Error Handling**: Detailed error types and error wrapping
+- **Timeout Management**: Context-aware deadline handling
+- **Observability**: Built-in statistics tracking
+
+## Usage
+
+### Basic Message Sending/Receiving
+
+```go
+// Create a message
+msg, err := codec.NewMessage(consts.CmdPublish, payload, "my-queue", headers)
+if err != nil {
+ log.Fatalf("Failed to create message: %v", err)
+}
+
+// Send the message
+ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+defer cancel()
+
+codec := codec.NewCodec(codec.DefaultConfig())
+if err := codec.SendMessage(ctx, conn, msg); err != nil {
+ log.Fatalf("Failed to send message: %v", err)
+}
+
+// Receive a message
+ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+defer cancel()
+
+receivedMsg, err := codec.ReadMessage(ctx, conn)
+if err != nil {
+ log.Fatalf("Failed to receive message: %v", err)
+}
+```
+
+### Custom Serialization
+
+```go
+// Set a custom marshaller/unmarshaller
+codec.SetMarshaller(func(v any) ([]byte, error) {
+ // Custom serialization logic
+ return someCustomSerializer.Marshal(v)
+})
+
+codec.SetUnmarshaller(func(data []byte, v any) error {
+ // Custom deserialization logic
+ return someCustomSerializer.Unmarshal(data, v)
+})
+```
+
+### Enabling Compression
+
+```go
+// Enable compression globally
+codec.EnableCompression(true)
+
+// Or configure it per codec instance
+config := codec.DefaultConfig()
+config.EnableCompression = true
+codec := codec.NewCodec(config)
+```
+
+### Enabling Encryption
+
+```go
+// Generate a secure key
+key := make([]byte, 32) // 256-bit key
+if _, err := rand.Read(key); err != nil {
+ log.Fatalf("Failed to generate key: %v", err)
+}
+
+// Enable encryption globally
+codec.EnableEncryption(true, key)
+
+// Or configure it per serialization manager
+config := codec.DefaultSerializationConfig()
+config.EnableEncryption = true
+config.EncryptionKey = key
+config.PreferredCipher = "chacha20poly1305" // or "aes-gcm"
+```
+
+### Connection Health Monitoring
+
+```go
+// Create a heartbeat manager
+c := codec.NewCodec(codec.DefaultConfig())
+hm := codec.NewHeartbeatManager(c, conn)
+
+// Configure heartbeat
+hm.SetInterval(15 * time.Second)
+hm.SetTimeout(45 * time.Second)
+hm.SetOnFailure(func(err error) {
+ log.Printf("Heartbeat failure: %v", err)
+ // Take action like closing connection
+})
+
+// Start heartbeat monitoring
+hm.Start()
+defer hm.Stop()
+```
+
+## Configuration
+
+The codec behavior can be customized through the `Config` struct:
+
+```go
+config := &codec.Config{
+ MaxMessageSize: 32 * 1024 * 1024, // 32MB max message size
+ MaxHeaderSize: 64 * 1024, // 64KB max header size
+ MaxQueueLength: 128, // Max queue name length
+ ReadTimeout: 15 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ EnableCompression: true,
+ BufferPoolSize: 2000,
+}
+
+codec := codec.NewCodec(config)
+```
+
+## Error Handling
+
+The codec provides detailed error types for different failure scenarios:
+
+- 
`ErrMessageTooLarge`: Message exceeds maximum size +- `ErrInvalidMessage`: Invalid message format +- `ErrInvalidQueue`: Invalid queue name +- `ErrInvalidCommand`: Invalid command +- `ErrConnectionClosed`: Connection closed +- `ErrTimeout`: Operation timeout +- `ErrProtocolMismatch`: Protocol version mismatch +- `ErrFragmentationRequired`: Message requires fragmentation +- `ErrInvalidFragment`: Invalid message fragment +- `ErrFragmentTimeout`: Timed out waiting for fragments +- `ErrFragmentMissing`: Missing fragments in sequence + +Error handling example: + +```go +if err := codec.SendMessage(ctx, conn, msg); err != nil { + if errors.Is(err, codec.ErrMessageTooLarge) { + // Handle message size error + } else if errors.Is(err, codec.ErrTimeout) { + // Handle timeout error + } else { + // Handle other errors + } +} +``` + +## Testing + +The codec package includes testing utilities for validation and performance testing: + +```go +ts := codec.NewCodecTestSuite() + +// Test basic message sending/receiving +msg, _ := codec.NewMessage(consts.CmdPublish, []byte("test"), "test-queue", nil) +if err := ts.SendReceiveTest(msg); err != nil { + log.Fatalf("Test failed: %v", err) +} + +// Test fragmentation/reassembly +largePayload := make([]byte, 20*1024*1024) // 20MB payload +rand.Read(largePayload) // Fill with random data +if err := ts.FragmentationTest(largePayload); err != nil { + log.Fatalf("Fragmentation test failed: %v", err) +} +``` diff --git a/codec/codec.go b/codec/codec.go index 92e7d36..fd9dd7e 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -4,103 +4,492 @@ import ( "bufio" "context" "encoding/binary" + "errors" + "fmt" "io" // added for full reads "net" + "sync" "time" // added for handling deadlines "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/internal/bpool" ) -type Message struct { - Headers map[string]string `msgpack:"h"` - Queue string `msgpack:"q"` - Payload []byte `msgpack:"p"` - Command consts.CMD `msgpack:"c"` +// Protocol version for backward compatibility +const ( + ProtocolVersion = uint8(1) + MaxMessageSize = 64 * 1024 * 1024 // 64MB default limit + MaxHeaderSize = 1024 * 1024 // 1MB header limit + MaxQueueLength = 255 // Max queue name length + FragmentationThreshold = 16 * 1024 * 1024 // Messages larger than 16MB will be fragmented + FragmentSize = 8 * 1024 * 1024 // 8MB fragment size + MaxFragments = 256 // Maximum fragments per message +) + +// Error definitions +var ( + ErrMessageTooLarge = errors.New("message exceeds maximum size") + ErrInvalidMessage = errors.New("invalid message format") + ErrInvalidQueue = errors.New("invalid queue name") + ErrInvalidCommand = errors.New("invalid command") + ErrConnectionClosed = errors.New("connection closed") + ErrTimeout = errors.New("operation timeout") + ErrProtocolMismatch = errors.New("protocol version mismatch") + ErrFragmentationRequired = errors.New("message requires fragmentation") + ErrInvalidFragment = errors.New("invalid message fragment") + ErrFragmentTimeout = errors.New("timed out waiting for fragments") + ErrFragmentMissing = errors.New("missing fragments in sequence") +) + +// Config holds codec configuration +type Config struct { + MaxMessageSize uint32 + MaxHeaderSize uint32 + MaxQueueLength uint8 + ReadTimeout time.Duration + WriteTimeout time.Duration + EnableCompression bool + BufferPoolSize int } -func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message { +// DefaultConfig returns default configuration +func DefaultConfig() *Config { + return &Config{ + 
MaxMessageSize: MaxMessageSize, + MaxHeaderSize: MaxHeaderSize, + MaxQueueLength: MaxQueueLength, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + EnableCompression: false, + BufferPoolSize: 1000, + } +} + +// MessageType indicates the type of message being sent +type MessageType uint8 + +const ( + MessageTypeStandard MessageType = iota + MessageTypeFragment + MessageTypeHeartbeat + MessageTypeAck + MessageTypeError +) + +// MessageFlag represents various flags that can be set on messages +type MessageFlag uint16 + +const ( + FlagNone MessageFlag = 0 + FlagFragmented MessageFlag = 1 << iota + FlagCompressed + FlagEncrypted + FlagHighPriority + FlagRedelivered + FlagNoAck +) + +// Message represents a protocol message with validation +type Message struct { + Headers map[string]string `msgpack:"h" json:"headers"` + Queue string `msgpack:"q" json:"queue"` + Payload []byte `msgpack:"p" json:"payload"` + Command consts.CMD `msgpack:"c" json:"command"` + Version uint8 `msgpack:"v" json:"version"` + Timestamp int64 `msgpack:"t" json:"timestamp"` + ID string `msgpack:"i" json:"id,omitempty"` + Flags MessageFlag `msgpack:"f" json:"flags"` + Type MessageType `msgpack:"mt" json:"messageType"` + FragmentID uint32 `msgpack:"fid" json:"fragmentId,omitempty"` + Fragments uint16 `msgpack:"fs" json:"fragments,omitempty"` + Sequence uint16 `msgpack:"seq" json:"sequence,omitempty"` +} + +// Codec handles message encoding/decoding with configuration +type Codec struct { + config *Config + mu sync.RWMutex + stats *Stats +} + +// Stats tracks codec statistics +type Stats struct { + MessagesSent uint64 + MessagesReceived uint64 + BytesSent uint64 + BytesReceived uint64 + Errors uint64 + mu sync.RWMutex +} + +// NewCodec creates a new codec with configuration +func NewCodec(config *Config) *Codec { + if config == nil { + config = DefaultConfig() + } + return &Codec{ + config: config, + stats: &Stats{}, + } +} + +// NewMessage creates a validated message +func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) (*Message, error) { + if err := validateCommand(cmd); err != nil { + return nil, err + } + + if err := validateQueue(queue); err != nil { + return nil, err + } + if headers == nil { headers = make(map[string]string) } - return &Message{ - Headers: headers, - Queue: queue, - Command: cmd, - Payload: payload, - } -} -func (m *Message) Serialize() ([]byte, error) { - data, err := Marshal(m) - if err != nil { + if err := validateHeaders(headers); err != nil { return nil, err } + + return &Message{ + Headers: headers, + Queue: queue, + Command: cmd, + Payload: payload, + Version: ProtocolVersion, + Timestamp: time.Now().Unix(), + }, nil +} + +// Validate performs message validation +func (m *Message) Validate(config *Config) error { + if m == nil { + return ErrInvalidMessage + } + + if m.Version != ProtocolVersion { + return ErrProtocolMismatch + } + + if err := validateCommand(m.Command); err != nil { + return err + } + + if err := validateQueue(m.Queue); err != nil { + return err + } + + if err := validateHeaders(m.Headers); err != nil { + return err + } + + if len(m.Payload) > int(config.MaxMessageSize) { + return ErrMessageTooLarge + } + + return nil +} + +// Serialize converts message to bytes with validation +func (m *Message) Serialize() ([]byte, error) { + if m == nil { + return nil, ErrInvalidMessage + } + + data, err := Marshal(m) + if err != nil { + return nil, fmt.Errorf("serialization failed: %w", err) + } + return data, nil } +// Deserialize 
converts bytes to message with validation func Deserialize(data []byte) (*Message, error) { + if len(data) == 0 { + return nil, ErrInvalidMessage + } + var msg Message if err := Unmarshal(data, &msg); err != nil { - return nil, err + return nil, fmt.Errorf("deserialization failed: %w", err) } + return &msg, nil } -func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error { +// SendMessage sends a message with proper error handling and timeouts +func (c *Codec) SendMessage(ctx context.Context, conn net.Conn, msg *Message) error { + // Check context cancellation before proceeding + if err := ctx.Err(); err != nil { + c.incrementErrors() + return fmt.Errorf("context ended before send: %w", err) + } + + if msg == nil { + return ErrInvalidMessage + } + + // Validate message + if err := msg.Validate(c.config); err != nil { + c.incrementErrors() + return fmt.Errorf("message validation failed: %w", err) + } + + // Check if this is a fragment message, if so handle it directly + if msg.Type == MessageTypeFragment { + return c.sendRawMessage(ctx, conn, msg) + } + + // Handle fragmentation for large messages if needed + if len(msg.Payload) > int(FragmentationThreshold) && msg.Type != MessageTypeFragment { + fm := NewFragmentManager(c, c.config) + defer fm.Stop() + return fm.sendFragmentedMessage(ctx, conn, msg) + } + + // Standard message send path + return c.sendRawMessage(ctx, conn, msg) +} + +// sendRawMessage handles the actual sending of a message or fragment +func (c *Codec) sendRawMessage(ctx context.Context, conn net.Conn, msg *Message) error { + // Serialize message data, err := msg.Serialize() if err != nil { - return err + c.incrementErrors() + return fmt.Errorf("message serialization failed: %w", err) } + + // Check message size + if len(data) > int(c.config.MaxMessageSize) { + c.incrementErrors() + return ErrMessageTooLarge + } + + // Prepare buffer totalLength := 4 + len(data) buffer := bpool.Get() defer bpool.Put(buffer) buffer.Reset() + if cap(buffer.B) < totalLength { buffer.B = make([]byte, totalLength) } else { buffer.B = buffer.B[:totalLength] } + + // Write length prefix and data binary.BigEndian.PutUint32(buffer.B[:4], uint32(len(data))) copy(buffer.B[4:], data) + + // Set timeout + deadline := time.Now().Add(c.config.WriteTimeout) + if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) { + deadline = ctxDeadline + } + + if err := conn.SetWriteDeadline(deadline); err != nil { + c.incrementErrors() + return fmt.Errorf("failed to set write deadline: %w", err) + } + defer conn.SetWriteDeadline(time.Time{}) + + // Write with buffering writer := bufio.NewWriter(conn) - // Set write deadline if context has one - if deadline, ok := ctx.Deadline(); ok { - conn.SetWriteDeadline(deadline) - defer conn.SetWriteDeadline(time.Time{}) + written, err := writer.Write(buffer.B[:totalLength]) + if err != nil { + c.incrementErrors() + return fmt.Errorf("write failed after %d bytes: %w", written, err) } - // Write full data - if _, err := writer.Write(buffer.B[:totalLength]); err != nil { - return err + if err := writer.Flush(); err != nil { + c.incrementErrors() + return fmt.Errorf("flush failed: %w", err) } - return writer.Flush() + + // Update statistics + c.stats.mu.Lock() + c.stats.MessagesSent++ + c.stats.BytesSent += uint64(totalLength) + c.stats.mu.Unlock() + + return nil } -func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) { +// ReadMessage reads a message with proper error handling and timeouts +func (c *Codec) ReadMessage(ctx 
context.Context, conn net.Conn) (*Message, error) { + // Check context cancellation before proceeding + if err := ctx.Err(); err != nil { + c.incrementErrors() + return nil, fmt.Errorf("context ended before read: %w", err) + } + + // Check context cancellation before proceeding + if err := ctx.Err(); err != nil { + c.incrementErrors() + return nil, fmt.Errorf("context ended before read: %w", err) + } + + // Set timeout + deadline := time.Now().Add(c.config.ReadTimeout) + if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) { + deadline = ctxDeadline + } + + if err := conn.SetReadDeadline(deadline); err != nil { + c.incrementErrors() + return nil, fmt.Errorf("failed to set read deadline: %w", err) + } + defer conn.SetReadDeadline(time.Time{}) + + // Read length prefix lengthBytes := make([]byte, 4) - // Set read deadline if context has one - if deadline, ok := ctx.Deadline(); ok { - conn.SetReadDeadline(deadline) - defer conn.SetReadDeadline(time.Time{}) - } - // Use io.ReadFull to ensure all header bytes are read if _, err := io.ReadFull(conn, lengthBytes); err != nil { - return nil, err + c.incrementErrors() + if errors.Is(err, io.EOF) { + return nil, ErrConnectionClosed + } + return nil, fmt.Errorf("failed to read message length: %w", err) } + length := binary.BigEndian.Uint32(lengthBytes) + + // Validate message size + if length > c.config.MaxMessageSize { + c.incrementErrors() + return nil, ErrMessageTooLarge + } + + if length == 0 { + c.incrementErrors() + return nil, ErrInvalidMessage + } + + // Read message data buffer := bpool.Get() defer bpool.Put(buffer) buffer.Reset() + if cap(buffer.B) < int(length) { buffer.B = make([]byte, length) } else { buffer.B = buffer.B[:length] } - // Read the entire message payload + if _, err := io.ReadFull(conn, buffer.B[:length]); err != nil { - return nil, err + c.incrementErrors() + if errors.Is(err, io.EOF) { + return nil, ErrConnectionClosed + } + return nil, fmt.Errorf("failed to read message data: %w", err) } - return Deserialize(buffer.B[:length]) + + // Deserialize message + msg, err := Deserialize(buffer.B[:length]) + if err != nil { + c.incrementErrors() + return nil, fmt.Errorf("failed to deserialize message: %w", err) + } + + // Validate message + if err := msg.Validate(c.config); err != nil { + c.incrementErrors() + return nil, fmt.Errorf("message validation failed: %w", err) + } + + // Handle message fragments if needed + if msg.Type == MessageTypeFragment || (msg.Flags&FlagFragmented) != 0 { + fm := NewFragmentManager(c, c.config) + reassembled, isFragment, err := fm.processFragment(msg) + if err != nil { + c.incrementErrors() + return nil, fmt.Errorf("fragment processing failed: %w", err) + } + + // If this is a fragment but reassembly isn't complete yet + if isFragment && reassembled == nil { + // Update statistics but return nil with no error to indicate + // the caller should continue reading messages + c.stats.mu.Lock() + c.stats.MessagesReceived++ + c.stats.BytesReceived += uint64(4 + length) + c.stats.mu.Unlock() + + // Read the next fragment + return c.ReadMessage(ctx, conn) + } + + // Use the reassembled message if available + if reassembled != nil { + msg = reassembled + } + } + + // Update statistics + c.stats.mu.Lock() + c.stats.MessagesReceived++ + c.stats.BytesReceived += uint64(4 + length) + c.stats.mu.Unlock() + + return msg, nil +} + +// GetStats returns codec statistics +func (c *Codec) GetStats() Stats { + c.stats.mu.RLock() + defer c.stats.mu.RUnlock() + return *c.stats +} + +// ResetStats 
resets codec statistics +func (c *Codec) ResetStats() { + c.stats.mu.Lock() + defer c.stats.mu.Unlock() + *c.stats = Stats{} +} + +// Helper functions for validation +func validateCommand(cmd consts.CMD) error { + // Add validation based on your command constants + if cmd < 0 { + return ErrInvalidCommand + } + return nil +} + +func validateQueue(queue string) error { + if len(queue) == 0 || len(queue) > MaxQueueLength { + return ErrInvalidQueue + } + return nil +} + +func validateHeaders(headers map[string]string) error { + totalSize := 0 + for k, v := range headers { + totalSize += len(k) + len(v) + if totalSize > MaxHeaderSize { + return ErrMessageTooLarge + } + } + return nil +} + +func (c *Codec) incrementErrors() { + c.stats.mu.Lock() + c.stats.Errors++ + c.stats.mu.Unlock() +} + +// SendMessage Backward compatibility functions +func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error { + codec := NewCodec(DefaultConfig()) + return codec.SendMessage(ctx, conn, msg) +} + +func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) { + codec := NewCodec(DefaultConfig()) + return codec.ReadMessage(ctx, conn) } diff --git a/codec/fragment.go b/codec/fragment.go new file mode 100644 index 0000000..edd0600 --- /dev/null +++ b/codec/fragment.go @@ -0,0 +1,282 @@ +package codec + +import ( + "context" + "encoding/binary" + "fmt" + "hash/crc32" + "net" + "sync" + "time" + + "github.com/oarkflow/mq/consts" +) + +// FragmentManager handles message fragmentation and reassembly +type FragmentManager struct { + codec *Codec + config *Config + fragmentStore map[string]*fragmentAssembly + mu sync.RWMutex + cleanupInterval time.Duration + cleanupTimer *time.Timer + cleanupChan chan struct{} + reassemblyTimeout time.Duration +} + +// fragmentAssembly holds fragments for a specific message +type fragmentAssembly struct { + fragments map[uint16][]byte + totalSize int + createdAt time.Time + totalCount uint16 + messageType MessageType + queue string + command consts.CMD + headers map[string]string + id string +} + +// NewFragmentManager creates a new fragment manager +func NewFragmentManager(codec *Codec, config *Config) *FragmentManager { + fm := &FragmentManager{ + codec: codec, + config: config, + fragmentStore: make(map[string]*fragmentAssembly), + cleanupInterval: time.Minute, + reassemblyTimeout: 5 * time.Minute, + } + + fm.startCleanupTimer() + return fm +} + +// startCleanupTimer starts a timer to clean up old fragments +func (fm *FragmentManager) startCleanupTimer() { + fm.cleanupChan = make(chan struct{}) + fm.cleanupTimer = time.NewTimer(fm.cleanupInterval) + + go func() { + for { + select { + case <-fm.cleanupTimer.C: + fm.cleanupExpiredFragments() + fm.cleanupTimer.Reset(fm.cleanupInterval) + case <-fm.cleanupChan: + return + } + } + }() +} + +// Stop stops the fragment manager +func (fm *FragmentManager) Stop() { + if fm.cleanupTimer != nil { + fm.cleanupTimer.Stop() + } + + if fm.cleanupChan != nil { + close(fm.cleanupChan) + } +} + +// cleanupExpiredFragments removes expired fragment assemblies +func (fm *FragmentManager) cleanupExpiredFragments() { + fm.mu.Lock() + defer fm.mu.Unlock() + + now := time.Now() + for id, assembly := range fm.fragmentStore { + if now.Sub(assembly.createdAt) > fm.reassemblyTimeout { + delete(fm.fragmentStore, id) + } + } +} + +// fragmentMessage splits a large message into fragments +func (fm *FragmentManager) fragmentMessage(ctx context.Context, msg *Message) ([]*Message, error) { + if msg == nil { + return nil, ErrInvalidMessage + } 
+ + // Check if fragmentation is needed + if len(msg.Payload) <= int(FragmentationThreshold) { + return []*Message{msg}, nil + } + + // Calculate how many fragments we need + fragmentCount := (len(msg.Payload) + int(FragmentSize) - 1) / int(FragmentSize) + if fragmentCount > int(MaxFragments) { + return nil, fmt.Errorf("%w: would require %d fragments, maximum is %d", + ErrMessageTooLarge, fragmentCount, MaxFragments) + } + + // Generate a fragment ID using a hash of message content + timestamp + idBytes := []byte(msg.ID) + timestampBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timestampBytes, uint64(msg.Timestamp)) + hashInput := append(idBytes, timestampBytes...) + hashInput = append(hashInput, msg.Payload[:min(1024, len(msg.Payload))]...) + fragmentID := crc32.ChecksumIEEE(hashInput) + + // Create fragment messages + fragments := make([]*Message, fragmentCount) + for i := 0; i < fragmentCount; i++ { + // Calculate fragment payload boundaries + start := i * int(FragmentSize) + end := min((i+1)*int(FragmentSize), len(msg.Payload)) + fragmentPayload := msg.Payload[start:end] + + // Create fragment message + fragment := &Message{ + Headers: copyHeaders(msg.Headers), + Queue: msg.Queue, + Command: msg.Command, + Version: msg.Version, + Timestamp: msg.Timestamp, + ID: msg.ID, + Type: MessageTypeFragment, + Flags: msg.Flags | FlagFragmented, + FragmentID: fragmentID, + Fragments: uint16(fragmentCount), + Sequence: uint16(i), + Payload: fragmentPayload, + } + + fragments[i] = fragment + } + + return fragments, nil +} + +// sendFragmentedMessage sends a large message as multiple fragments +func (fm *FragmentManager) sendFragmentedMessage(ctx context.Context, conn net.Conn, msg *Message) error { + fragments, err := fm.fragmentMessage(ctx, msg) + if err != nil { + return err + } + + // If no fragmentation was needed, send as normal + if len(fragments) == 1 && fragments[0].Type != MessageTypeFragment { + return fm.codec.SendMessage(ctx, conn, fragments[0]) + } + + // Send each fragment + for _, fragment := range fragments { + if err := fm.codec.SendMessage(ctx, conn, fragment); err != nil { + return fmt.Errorf("failed to send fragment %d/%d: %w", + fragment.Sequence+1, fragment.Fragments, err) + } + } + + return nil +} + +// processFragment processes a fragment and attempts reassembly +func (fm *FragmentManager) processFragment(msg *Message) (*Message, bool, error) { + if msg == nil || msg.Type != MessageTypeFragment || msg.FragmentID == 0 { + return msg, false, nil // Not a fragment, return as is + } + + // Generate a unique key for this fragmented message + key := fmt.Sprintf("%s-%d-%s", msg.ID, msg.FragmentID, msg.Queue) + + fm.mu.Lock() + defer fm.mu.Unlock() + + // Check if we already have an assembly for this message + assembly, exists := fm.fragmentStore[key] + if !exists { + // Create a new assembly + assembly = &fragmentAssembly{ + fragments: make(map[uint16][]byte), + createdAt: time.Now(), + totalCount: msg.Fragments, + messageType: MessageTypeStandard, + queue: msg.Queue, + command: msg.Command, + headers: copyHeaders(msg.Headers), + id: msg.ID, + } + fm.fragmentStore[key] = assembly + } + + // Store this fragment + assembly.fragments[msg.Sequence] = msg.Payload + assembly.totalSize += len(msg.Payload) + + // Check if we have all fragments + if len(assembly.fragments) == int(assembly.totalCount) { + // We have all fragments, reassemble the message + reassembled, err := fm.reassembleMessage(key, assembly) + if err != nil { + return nil, true, err + } + return reassembled, true, 
nil + } + + // Still waiting for more fragments + return nil, true, nil +} + +// reassembleMessage combines fragments into the original message +func (fm *FragmentManager) reassembleMessage(key string, assembly *fragmentAssembly) (*Message, error) { + // Remove the assembly from the store + delete(fm.fragmentStore, key) + + // Check if we have all fragments + if len(assembly.fragments) != int(assembly.totalCount) { + return nil, ErrFragmentMissing + } + + // Allocate space for the full payload + fullPayload := make([]byte, assembly.totalSize) + + // Combine fragments in order + offset := 0 + for i := uint16(0); i < assembly.totalCount; i++ { + fragment, exists := assembly.fragments[i] + if !exists { + return nil, fmt.Errorf("%w: missing fragment %d", ErrFragmentMissing, i) + } + + copy(fullPayload[offset:], fragment) + offset += len(fragment) + } + + // Create the reassembled message + msg := &Message{ + Headers: assembly.headers, + Queue: assembly.queue, + Command: assembly.command, + Version: ProtocolVersion, + Timestamp: time.Now().Unix(), + ID: assembly.id, + Type: assembly.messageType, + Flags: FlagNone, // Clear fragmentation flag + Payload: fullPayload, + } + + return msg, nil +} + +// Helper function to make a copy of headers to avoid shared references +func copyHeaders(src map[string]string) map[string]string { + if src == nil { + return nil + } + + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +// min returns the smaller of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/codec/heartbeat.go b/codec/heartbeat.go new file mode 100644 index 0000000..af6bf31 --- /dev/null +++ b/codec/heartbeat.go @@ -0,0 +1,193 @@ +package codec + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/oarkflow/mq/consts" +) + +// HeartbeatManager manages heartbeat messages for connection health monitoring +type HeartbeatManager struct { + codec *Codec + interval time.Duration + timeout time.Duration + conn net.Conn + stopChan chan struct{} + lastHeartbeat atomic.Int64 + lastReceived atomic.Int64 + onFailure func(error) + mu sync.RWMutex + isRunning bool + heartbeatsSent uint64 + heartbeatsRecv uint64 + failedHeartbeats uint64 +} + +// NewHeartbeatManager creates a new heartbeat manager +func NewHeartbeatManager(codec *Codec, conn net.Conn) *HeartbeatManager { + return &HeartbeatManager{ + codec: codec, + interval: 30 * time.Second, + timeout: 90 * time.Second, // 3x interval + conn: conn, + stopChan: make(chan struct{}), + } +} + +// SetInterval sets the heartbeat interval +func (hm *HeartbeatManager) SetInterval(interval time.Duration) { + hm.mu.Lock() + defer hm.mu.Unlock() + hm.interval = interval +} + +// SetTimeout sets the heartbeat timeout +func (hm *HeartbeatManager) SetTimeout(timeout time.Duration) { + hm.mu.Lock() + defer hm.mu.Unlock() + hm.timeout = timeout +} + +// SetOnFailure sets the callback function for heartbeat failures +func (hm *HeartbeatManager) SetOnFailure(fn func(error)) { + hm.mu.Lock() + defer hm.mu.Unlock() + hm.onFailure = fn +} + +// Start starts the heartbeat monitoring +func (hm *HeartbeatManager) Start() { + hm.mu.Lock() + defer hm.mu.Unlock() + + if hm.isRunning { + return + } + + hm.isRunning = true + hm.lastHeartbeat.Store(time.Now().Unix()) + hm.lastReceived.Store(time.Now().Unix()) + + go hm.sendHeartbeats() + go hm.monitorHeartbeats() +} + +// Stop stops the heartbeat monitoring +func (hm *HeartbeatManager) Stop() { + 
hm.mu.Lock() + defer hm.mu.Unlock() + + if !hm.isRunning { + return + } + + hm.isRunning = false + close(hm.stopChan) + // Create a new stop channel for future use + hm.stopChan = make(chan struct{}) +} + +// IsRunning returns whether the heartbeat manager is running +func (hm *HeartbeatManager) IsRunning() bool { + hm.mu.RLock() + defer hm.mu.RUnlock() + return hm.isRunning +} + +// GetStats returns heartbeat statistics +func (hm *HeartbeatManager) GetStats() map[string]uint64 { + return map[string]uint64{ + "sent": atomic.LoadUint64(&hm.heartbeatsSent), + "received": atomic.LoadUint64(&hm.heartbeatsRecv), + "failed": atomic.LoadUint64(&hm.failedHeartbeats), + } +} + +// RecordHeartbeat records a received heartbeat +func (hm *HeartbeatManager) RecordHeartbeat() { + hm.lastReceived.Store(time.Now().Unix()) + atomic.AddUint64(&hm.heartbeatsRecv, 1) +} + +// sendHeartbeats sends periodic heartbeat messages +func (hm *HeartbeatManager) sendHeartbeats() { + ticker := time.NewTicker(hm.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := hm.sendHeartbeat(); err != nil { + atomic.AddUint64(&hm.failedHeartbeats, 1) + hm.mu.RLock() + onFailure := hm.onFailure + hm.mu.RUnlock() + + if onFailure != nil { + onFailure(err) + } + } + case <-hm.stopChan: + return + } + } +} + +// monitorHeartbeats monitors the heartbeat health +func (hm *HeartbeatManager) monitorHeartbeats() { + ticker := time.NewTicker(hm.interval / 2) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + now := time.Now().Unix() + lastReceived := hm.lastReceived.Load() + + // Check if we've exceeded the timeout + if now-lastReceived > int64(hm.timeout.Seconds()) { + err := fmt.Errorf("heartbeat timeout: last received %v seconds ago", + now-lastReceived) + + atomic.AddUint64(&hm.failedHeartbeats, 1) + hm.mu.RLock() + onFailure := hm.onFailure + hm.mu.RUnlock() + + if onFailure != nil { + onFailure(err) + } + } + case <-hm.stopChan: + return + } + } +} + +// sendHeartbeat sends a single heartbeat message +func (hm *HeartbeatManager) sendHeartbeat() error { + msg := &Message{ + Type: MessageTypeHeartbeat, + Command: consts.CMD(0), // Use appropriate heartbeat command from your consts + Version: ProtocolVersion, + Timestamp: time.Now().Unix(), + Headers: map[string]string{"type": "heartbeat"}, + Payload: []byte{}, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := hm.codec.SendMessage(ctx, hm.conn, msg) + if err == nil { + hm.lastHeartbeat.Store(time.Now().Unix()) + atomic.AddUint64(&hm.heartbeatsSent, 1) + } + + return err +} diff --git a/codec/serializer.go b/codec/serializer.go index c9cf90f..9899663 100644 --- a/codec/serializer.go +++ b/codec/serializer.go @@ -1,39 +1,418 @@ package codec import ( + "bytes" + "compress/gzip" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" + "time" + "github.com/oarkflow/json" + "golang.org/x/crypto/chacha20poly1305" ) -type MarshallerFunc func(v any) ([]byte, error) +// Error definitions for serialization +var ( + ErrSerializationFailed = errors.New("serialization failed") + ErrDeserializationFailed = errors.New("deserialization failed") + ErrCompressionFailed = errors.New("compression failed") + ErrDecompressionFailed = errors.New("decompression failed") + ErrEncryptionFailed = errors.New("encryption failed") + ErrDecryptionFailed = errors.New("decryption failed") + ErrInvalidKey = errors.New("invalid encryption key") +) -type UnmarshallerFunc 
func(data []byte, v any) error +// ContentType represents the content type of serialized data +type ContentType string + +const ( + ContentTypeJSON ContentType = "application/json" + ContentTypeMsgPack ContentType = "application/msgpack" + ContentTypeCBOR ContentType = "application/cbor" +) + +// Marshaller interface for pluggable serialization +type Marshaller interface { + Marshal(v any) ([]byte, error) + ContentType() ContentType +} + +// Unmarshaller interface for pluggable deserialization +type Unmarshaller interface { + Unmarshal(data []byte, v any) error + ContentType() ContentType +} + +// MarshallerFunc adapter +type MarshallerFunc func(v any) ([]byte, error) func (f MarshallerFunc) Marshal(v any) ([]byte, error) { return f(v) } +func (f MarshallerFunc) ContentType() ContentType { + return ContentTypeJSON +} + +// UnmarshallerFunc adapter +type UnmarshallerFunc func(data []byte, v any) error + func (f UnmarshallerFunc) Unmarshal(data []byte, v any) error { return f(data, v) } -var defaultMarshaller MarshallerFunc = json.Marshal - -var defaultUnmarshaller UnmarshallerFunc = func(data []byte, v any) error { - return json.Unmarshal(data, v) +func (f UnmarshallerFunc) ContentType() ContentType { + return ContentTypeJSON } +// SerializationConfig holds serialization configuration +type SerializationConfig struct { + EnableCompression bool + CompressionLevel int + MaxCompressionRatio float64 + EnableEncryption bool + EncryptionKey []byte + PreferredCipher string // "chacha20poly1305" or "aes-gcm" +} + +// DefaultSerializationConfig returns default configuration +func DefaultSerializationConfig() *SerializationConfig { + return &SerializationConfig{ + EnableCompression: false, + CompressionLevel: gzip.DefaultCompression, + MaxCompressionRatio: 0.8, // Only compress if we save at least 20% + EnableEncryption: false, + PreferredCipher: "chacha20poly1305", + } +} + +// SerializationManager manages serialization with configuration +type SerializationManager struct { + marshaller Marshaller + unmarshaller Unmarshaller + config *SerializationConfig + mu sync.RWMutex + cachedCipher cipher.AEAD + cipherMu sync.Mutex +} + +// NewSerializationManager creates a new serialization manager +func NewSerializationManager(config *SerializationConfig) *SerializationManager { + if config == nil { + config = DefaultSerializationConfig() + } + + return &SerializationManager{ + marshaller: MarshallerFunc(json.Marshal), + unmarshaller: UnmarshallerFunc(func(data []byte, v any) error { + return json.Unmarshal(data, v) + }), + config: config, + } +} + +// SetMarshaller sets custom marshaller +func (sm *SerializationManager) SetMarshaller(marshaller Marshaller) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.marshaller = marshaller +} + +// SetUnmarshaller sets custom unmarshaller +func (sm *SerializationManager) SetUnmarshaller(unmarshaller Unmarshaller) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.unmarshaller = unmarshaller +} + +// SetEncryptionKey sets the encryption key +func (sm *SerializationManager) SetEncryptionKey(key []byte) error { + if sm.config.EnableEncryption && len(key) == 0 { + return ErrInvalidKey + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + sm.config.EncryptionKey = key + // Clear the cached cipher so it will be recreated with the new key + sm.cipherMu.Lock() + defer sm.cipherMu.Unlock() + sm.cachedCipher = nil + + return nil +} + +// Marshal serializes data with optional compression and encryption +func (sm *SerializationManager) Marshal(v any) ([]byte, error) { + sm.mu.RLock() + marshaller 
:= sm.marshaller + config := sm.config + sm.mu.RUnlock() + + // Serialize the data + data, err := marshaller.Marshal(v) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrSerializationFailed, err) + } + + // Apply compression if enabled and beneficial + if config.EnableCompression && len(data) > 256 { // Only compress larger payloads + compressed, err := sm.compress(data) + if err != nil { + // Continue with uncompressed data on error + // but don't return error to allow for graceful degradation + } else if float64(len(compressed))/float64(len(data)) <= config.MaxCompressionRatio { + // Only use compression if it provides significant benefit + data = compressed + } + } + + // Apply encryption if enabled + if config.EnableEncryption && len(config.EncryptionKey) > 0 { + encrypted, err := sm.encrypt(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err) + } + data = encrypted + } + + return data, nil +} + +// Unmarshal deserializes data with optional decompression and decryption +func (sm *SerializationManager) Unmarshal(data []byte, v any) error { + if len(data) == 0 { + return fmt.Errorf("%w: empty data", ErrDeserializationFailed) + } + + sm.mu.RLock() + unmarshaller := sm.unmarshaller + config := sm.config + sm.mu.RUnlock() + + // Apply decryption if enabled + if config.EnableEncryption && len(config.EncryptionKey) > 0 { + decrypted, err := sm.decrypt(data) + if err != nil { + return fmt.Errorf("%w: %v", ErrDecryptionFailed, err) + } + data = decrypted + } + + // Try decompression if enabled + if config.EnableCompression && sm.isCompressed(data) { + decompressed, err := sm.decompress(data) + if err != nil { + return fmt.Errorf("%w: %v", ErrDecompressionFailed, err) + } + data = decompressed + } + + // Deserialize the data + if err := unmarshaller.Unmarshal(data, v); err != nil { + return fmt.Errorf("%w: %v", ErrDeserializationFailed, err) + } + + return nil +} + +// compress compresses data using gzip +func (sm *SerializationManager) compress(data []byte) ([]byte, error) { + var buf bytes.Buffer + + // Add compression marker + buf.WriteByte(0x1f) // gzip magic number + + writer, err := gzip.NewWriterLevel(&buf, sm.config.CompressionLevel) + if err != nil { + return nil, err + } + + if _, err := writer.Write(data); err != nil { + writer.Close() + return nil, err + } + + if err := writer.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// decompress decompresses gzip data +func (sm *SerializationManager) decompress(data []byte) ([]byte, error) { + if len(data) < 1 { + return nil, fmt.Errorf("invalid compressed data") + } + + // Remove compression marker + if data[0] != 0x1f { + return data, nil // Not compressed + } + + reader, err := gzip.NewReader(bytes.NewReader(data[1:])) + if err != nil { + return nil, err + } + defer reader.Close() + + var buf bytes.Buffer + if _, err := buf.ReadFrom(reader); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// isCompressed checks if data is compressed +func (sm *SerializationManager) isCompressed(data []byte) bool { + return len(data) > 0 && data[0] == 0x1f +} + +// getCipher returns a cipher.AEAD instance for encryption/decryption +func (sm *SerializationManager) getCipher() (cipher.AEAD, error) { + sm.cipherMu.Lock() + defer sm.cipherMu.Unlock() + + if sm.cachedCipher != nil { + return sm.cachedCipher, nil + } + + var aead cipher.AEAD + var err error + + sm.mu.RLock() + key := sm.config.EncryptionKey + preferredCipher := sm.config.PreferredCipher + sm.mu.RUnlock() + + 
switch preferredCipher { + case "chacha20poly1305": + aead, err = chacha20poly1305.New(key) + case "aes-gcm": + block, e := aes.NewCipher(key) + if e != nil { + err = e + break + } + aead, err = cipher.NewGCM(block) + default: + // Default to ChaCha20-Poly1305 + aead, err = chacha20poly1305.New(key) + } + + if err != nil { + return nil, err + } + + sm.cachedCipher = aead + return aead, nil +} + +// encrypt encrypts data using authenticated encryption +func (sm *SerializationManager) encrypt(data []byte) ([]byte, error) { + aead, err := sm.getCipher() + if err != nil { + return nil, err + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + // Add timestamp to associated data for replay protection + timestampBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timestampBytes, uint64(time.Now().Unix())) + + // Encrypt and authenticate + ciphertext := aead.Seal(nil, nonce, data, timestampBytes) + + // Prepend nonce and timestamp to ciphertext + result := make([]byte, len(nonce)+len(timestampBytes)+len(ciphertext)) + copy(result, nonce) + copy(result[len(nonce):], timestampBytes) + copy(result[len(nonce)+len(timestampBytes):], ciphertext) + + return result, nil +} + +// decrypt decrypts data using authenticated decryption +func (sm *SerializationManager) decrypt(data []byte) ([]byte, error) { + aead, err := sm.getCipher() + if err != nil { + return nil, err + } + + // Extract nonce + nonceSize := aead.NonceSize() + if len(data) < nonceSize+8 { // 8 bytes for timestamp + return nil, fmt.Errorf("ciphertext too short") + } + + nonce := data[:nonceSize] + timestampBytes := data[nonceSize : nonceSize+8] + ciphertext := data[nonceSize+8:] + + // Decrypt and verify + plaintext, err := aead.Open(nil, nonce, ciphertext, timestampBytes) + if err != nil { + return nil, err + } + + return plaintext, nil +} + +// Global serialization manager instance +var globalSerializationManager = NewSerializationManager(nil) +var globalSerializationManagerMu sync.RWMutex + +// Global functions for backward compatibility func SetMarshaller(marshaller MarshallerFunc) { - defaultMarshaller = marshaller + globalSerializationManagerMu.Lock() + defer globalSerializationManagerMu.Unlock() + globalSerializationManager.SetMarshaller(marshaller) } func SetUnmarshaller(unmarshaller UnmarshallerFunc) { - defaultUnmarshaller = unmarshaller + globalSerializationManagerMu.Lock() + defer globalSerializationManagerMu.Unlock() + globalSerializationManager.SetUnmarshaller(unmarshaller) +} + +func EnableCompression(enable bool) { + globalSerializationManagerMu.Lock() + defer globalSerializationManagerMu.Unlock() + globalSerializationManager.config.EnableCompression = enable +} + +func EnableEncryption(enable bool, key []byte) error { + globalSerializationManagerMu.Lock() + defer globalSerializationManagerMu.Unlock() + globalSerializationManager.config.EnableEncryption = enable + if enable { + return globalSerializationManager.SetEncryptionKey(key) + } + return nil } func Marshal(v any) ([]byte, error) { - return defaultMarshaller(v) + globalSerializationManagerMu.RLock() + defer globalSerializationManagerMu.RUnlock() + return globalSerializationManager.Marshal(v) } func Unmarshal(data []byte, v any) error { - return defaultUnmarshaller(data, v) + globalSerializationManagerMu.RLock() + defer globalSerializationManagerMu.RUnlock() + return globalSerializationManager.Unmarshal(data, v) } diff --git a/codec/testing.go 
b/codec/testing.go new file mode 100644 index 0000000..61e0c00 --- /dev/null +++ b/codec/testing.go @@ -0,0 +1,210 @@ +package codec + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/oarkflow/mq/consts" +) + +// MockConn implements net.Conn for testing +type MockConn struct { + ReadBuffer *bytes.Buffer + WriteBuffer *bytes.Buffer + ReadDelay time.Duration + WriteDelay time.Duration + IsClosed bool + ReadErr error + WriteErr error + mu sync.Mutex +} + +// NewMockConn creates a new mock connection +func NewMockConn() *MockConn { + return &MockConn{ + ReadBuffer: bytes.NewBuffer(nil), + WriteBuffer: bytes.NewBuffer(nil), + } +} + +// Read implements the net.Conn Read method +func (m *MockConn) Read(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.IsClosed { + return 0, errors.New("connection closed") + } + + if m.ReadErr != nil { + return 0, m.ReadErr + } + + if m.ReadDelay > 0 { + time.Sleep(m.ReadDelay) + } + + return m.ReadBuffer.Read(b) +} + +// Write implements the net.Conn Write method +func (m *MockConn) Write(b []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.IsClosed { + return 0, errors.New("connection closed") + } + + if m.WriteErr != nil { + return 0, m.WriteErr + } + + if m.WriteDelay > 0 { + time.Sleep(m.WriteDelay) + } + + return m.WriteBuffer.Write(b) +} + +// Close implements the net.Conn Close method +func (m *MockConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.IsClosed = true + return nil +} + +// LocalAddr implements the net.Conn LocalAddr method +func (m *MockConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} +} + +// RemoteAddr implements the net.Conn RemoteAddr method +func (m *MockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} +} + +// SetDeadline implements the net.Conn SetDeadline method +func (m *MockConn) SetDeadline(t time.Time) error { + return nil +} + +// SetReadDeadline implements the net.Conn SetReadDeadline method +func (m *MockConn) SetReadDeadline(t time.Time) error { + return nil +} + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method +func (m *MockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// CodecTestSuite provides utilities for testing the codec +type CodecTestSuite struct { + Codec *Codec + Config *Config +} + +// NewCodecTestSuite creates a new codec test suite +func NewCodecTestSuite() *CodecTestSuite { + config := DefaultConfig() + // Set smaller timeouts for testing + config.ReadTimeout = 500 * time.Millisecond + config.WriteTimeout = 500 * time.Millisecond + + return &CodecTestSuite{ + Codec: NewCodec(config), + Config: config, + } +} + +// SendReceiveTest tests sending and receiving a message +func (ts *CodecTestSuite) SendReceiveTest(msg *Message) error { + conn := NewMockConn() + + // Send the message + ctx := context.Background() + if err := ts.Codec.SendMessage(ctx, conn, msg); err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + // Move written data to read buffer to simulate network transport + conn.ReadBuffer.Write(conn.WriteBuffer.Bytes()) + conn.WriteBuffer.Reset() + + // Receive the message + received, err := ts.Codec.ReadMessage(ctx, conn) + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + + // Validate the message + if received.Command != msg.Command { + return fmt.Errorf("command mismatch: got %v, want %v", received.Command, msg.Command) + } + + if 
received.Queue != msg.Queue { + return fmt.Errorf("queue mismatch: got %v, want %v", received.Queue, msg.Queue) + } + + if !bytes.Equal(received.Payload, msg.Payload) { + return fmt.Errorf("payload mismatch: got %d bytes, want %d bytes", len(received.Payload), len(msg.Payload)) + } + + return nil +} + +// FragmentationTest tests the fragmentation and reassembly of large messages +func (ts *CodecTestSuite) FragmentationTest(payload []byte) error { + msg := &Message{ + Command: consts.CMD(1), // Use appropriate command from your consts + Queue: "test_queue", + Headers: map[string]string{"test": "header"}, + Payload: payload, + Version: ProtocolVersion, + Timestamp: time.Now().Unix(), + ID: "test-message-id", + } + + conn := NewMockConn() + + // Configure fragmentation + fm := NewFragmentManager(ts.Codec, ts.Config) + + // Send the fragmented message + ctx := context.Background() + if err := fm.sendFragmentedMessage(ctx, conn, msg); err != nil { + return fmt.Errorf("failed to send fragmented message: %w", err) + } + + // Move written data to read buffer to simulate network transport + conn.ReadBuffer.Write(conn.WriteBuffer.Bytes()) + conn.WriteBuffer.Reset() + + // Receive and reassemble the message + received, err := ts.Codec.ReadMessage(ctx, conn) + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + + // Validate the reassembled message + if received.Command != msg.Command { + return fmt.Errorf("command mismatch: got %v, want %v", received.Command, msg.Command) + } + + if received.Queue != msg.Queue { + return fmt.Errorf("queue mismatch: got %v, want %v", received.Queue, msg.Queue) + } + + if !bytes.Equal(received.Payload, msg.Payload) { + return fmt.Errorf("payload mismatch: got %d bytes, want %d bytes", len(received.Payload), len(msg.Payload)) + } + + return nil +} diff --git a/config/production.json b/config/production.json deleted file mode 100644 index c58d1a3..0000000 --- a/config/production.json +++ /dev/null @@ -1,99 +0,0 @@ -{ - "broker": { - "address": "localhost", - "port": 8080, - "max_connections": 1000, - "connection_timeout": "5s", - "read_timeout": "300s", - "write_timeout": "30s", - "idle_timeout": "600s", - "keep_alive": true, - "keep_alive_period": "60s", - "max_queue_depth": 10000, - "enable_dead_letter": true, - "dead_letter_max_retries": 3 - }, - "consumer": { - "enable_http_api": true, - "max_retries": 5, - "initial_delay": "2s", - "max_backoff": "30s", - "jitter_percent": 0.5, - "batch_size": 10, - "prefetch_count": 100, - "auto_ack": false, - "requeue_on_failure": true - }, - "publisher": { - "enable_http_api": true, - "max_retries": 3, - "initial_delay": "1s", - "max_backoff": "10s", - "confirm_delivery": true, - "publish_timeout": "5s", - "connection_pool_size": 10 - }, - "pool": { - "queue_size": 1000, - "max_workers": 20, - "max_memory_load": 1073741824, - "idle_timeout": "300s", - "graceful_shutdown_timeout": "30s", - "task_timeout": "60s", - "enable_metrics": true, - "enable_diagnostics": true - }, - "security": { - "enable_tls": false, - "tls_cert_path": "./certs/server.crt", - "tls_key_path": "./certs/server.key", - "tls_ca_path": "./certs/ca.crt", - "enable_auth": false, - "auth_provider": "jwt", - "jwt_secret": "your-secret-key", - "enable_encryption": false, - "encryption_key": "32-byte-encryption-key-here!!" 
- }, - "monitoring": { - "metrics_port": 9090, - "health_check_port": 9091, - "enable_metrics": true, - "enable_health_checks": true, - "metrics_interval": "10s", - "health_check_interval": "30s", - "retention_period": "24h", - "enable_tracing": true, - "jaeger_endpoint": "http://localhost:14268/api/traces" - }, - "persistence": { - "enable": true, - "provider": "postgres", - "connection_string": "postgres://user:password@localhost:5432/mq_db?sslmode=disable", - "max_connections": 50, - "connection_timeout": "30s", - "enable_migrations": true, - "backup_enabled": true, - "backup_interval": "6h" - }, - "clustering": { - "enable": false, - "node_id": "node-1", - "cluster_name": "mq-cluster", - "peers": [ ], - "election_timeout": "5s", - "heartbeat_interval": "1s", - "enable_auto_discovery": false, - "discovery_port": 7946 - }, - "rate_limit": { - "broker_rate": 1000, - "broker_burst": 100, - "consumer_rate": 500, - "consumer_burst": 50, - "publisher_rate": 200, - "publisher_burst": 20, - "global_rate": 2000, - "global_burst": 200 - }, - "last_updated": "2025-07-29T00:00:00Z" -} diff --git a/consumer.go b/consumer.go index 181f0b2..01f4788 100644 --- a/consumer.go +++ b/consumer.go @@ -152,7 +152,10 @@ func (c *Consumer) Metrics() Metrics { func (c *Consumer) subscribe(ctx context.Context, queue string) error { headers := HeadersWithConsumerID(ctx, c.id) - msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) + msg, err := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) + if err != nil { + return fmt.Errorf("error creating subscribe message: %v", err) + } if err := c.send(ctx, c.conn, msg); err != nil { return fmt.Errorf("error while trying to subscribe: %v", err) } @@ -207,7 +210,14 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) { headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue) taskID, _ := jsonparser.GetString(msg.Payload, "id") - reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) + reply, err := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) + if err != nil { + c.logger.Error("Failed to create MESSAGE_ACK", + logger.Field{Key: "queue", Value: msg.Queue}, + logger.Field{Key: "task_id", Value: taskID}, + logger.Field{Key: "error", Value: err.Error()}) + return + } // Send with timeout to avoid blocking sendCtx, cancel := context.WithTimeout(ctx, 10*time.Second) @@ -375,7 +385,14 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error { } } bt, _ := json.Marshal(result) - reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers) + reply, err := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers) + if err != nil { + c.logger.Error("Failed to create MESSAGE_RESPONSE", + logger.Field{Key: "topic", Value: result.Topic}, + logger.Field{Key: "task_id", Value: result.TaskID}, + logger.Field{Key: "error", Value: err.Error()}) + return + } sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -395,7 +412,14 @@ func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, er // Send deny message asynchronously to avoid blocking go func() { headers := HeadersWithConsumerID(ctx, c.id) - reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, 
taskID, err.Error())), queue, headers) + reply, err := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) + if err != nil { + c.logger.Error("Failed to create MESSAGE_DENY", + logger.Field{Key: "queue", Value: queue}, + logger.Field{Key: "task_id", Value: taskID}, + logger.Field{Key: "error", Value: err.Error()}) + return + } sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -820,7 +844,10 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error { headers := HeadersWithConsumerID(ctx, c.id) - msg := codec.NewMessage(cmd, nil, c.queue, headers) + msg, err := codec.NewMessage(cmd, nil, c.queue, headers) + if err != nil { + return fmt.Errorf("error creating operation message: %v", err) + } return c.send(ctx, c.conn, msg) } diff --git a/examples/email_notification_dag.go b/examples/email_notification_dag.go index 1259b1c..89cde1a 100644 --- a/examples/email_notification_dag.go +++ b/examples/email_notification_dag.go @@ -10,6 +10,7 @@ import ( "github.com/oarkflow/json" "github.com/oarkflow/mq/dag" + "github.com/oarkflow/mq/utils" "github.com/oarkflow/jet" @@ -19,7 +20,7 @@ import ( func main() { flow := dag.NewDAG("Email Notification System", "email-notification", func(taskID string, result mq.Result) { - fmt.Printf("Email notification workflow completed for task %s: %s\n", taskID, string(result.Payload)) + fmt.Printf("Email notification workflow completed for task %s: %s\n", taskID, string(utils.RemoveRecursiveFromJSON(result.Payload, "html_content"))) }) // Add workflow nodes diff --git a/examples/enhanced_dag_demo.go b/examples/enhanced_dag_demo.go deleted file mode 100644 index acda8d1..0000000 --- a/examples/enhanced_dag_demo.go +++ /dev/null @@ -1,557 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net/http" - "os" - "os/signal" - "syscall" - "time" - - "github.com/oarkflow/mq" - "github.com/oarkflow/mq/dag" - "github.com/oarkflow/mq/logger" -) - -// ExampleProcessor demonstrates a custom processor with debugging -type ExampleProcessor struct { - name string - tags []string -} - -func NewExampleProcessor(name string) *ExampleProcessor { - return &ExampleProcessor{ - name: name, - tags: []string{"example", "demo"}, - } -} - -func (p *ExampleProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result { - // Simulate processing time - time.Sleep(100 * time.Millisecond) - - // Add some example processing logic - var data map[string]interface{} - if err := task.UnmarshalPayload(&data); err != nil { - return mq.Result{Error: err} - } - - // Process the data - data["processed_by"] = p.name - data["processed_at"] = time.Now() - - payload, _ := task.MarshalPayload(data) - return mq.Result{Payload: payload} -} - -func (p *ExampleProcessor) SetConfig(payload dag.Payload) {} -func (p *ExampleProcessor) SetTags(tags ...string) { p.tags = append(p.tags, tags...) 
} -func (p *ExampleProcessor) GetTags() []string { return p.tags } -func (p *ExampleProcessor) Consume(ctx context.Context) error { return nil } -func (p *ExampleProcessor) Pause(ctx context.Context) error { return nil } -func (p *ExampleProcessor) Resume(ctx context.Context) error { return nil } -func (p *ExampleProcessor) Stop(ctx context.Context) error { return nil } -func (p *ExampleProcessor) Close() error { return nil } -func (p *ExampleProcessor) GetType() string { return "example" } -func (p *ExampleProcessor) GetKey() string { return p.name } -func (p *ExampleProcessor) SetKey(key string) { p.name = key } - -// CustomActivityHook demonstrates custom activity processing -type CustomActivityHook struct { - logger logger.Logger -} - -func (h *CustomActivityHook) OnActivity(entry dag.ActivityEntry) error { - // Custom processing of activity entries - if entry.Level == dag.ActivityLevelError { - h.logger.Error("Critical activity detected", - logger.Field{Key: "activity_id", Value: entry.ID}, - logger.Field{Key: "dag_name", Value: entry.DAGName}, - logger.Field{Key: "message", Value: entry.Message}, - ) - - // Here you could send notifications, trigger alerts, etc. - } - return nil -} - -// CustomAlertHandler demonstrates custom alert handling -type CustomAlertHandler struct { - logger logger.Logger -} - -func (h *CustomAlertHandler) HandleAlert(alert dag.Alert) error { - h.logger.Warn("DAG Alert received", - logger.Field{Key: "type", Value: alert.Type}, - logger.Field{Key: "severity", Value: alert.Severity}, - logger.Field{Key: "message", Value: alert.Message}, - ) - - // Here you could integrate with external alerting systems - // like Slack, PagerDuty, email, etc. - - return nil -} - -func main() { - // Initialize logger - log := logger.New(logger.Config{ - Level: logger.LevelInfo, - Format: logger.FormatJSON, - }) - - // Create a comprehensive DAG with all enhanced features - server := mq.NewServer("demo", ":0", log) - - // Create DAG with comprehensive configuration - dagInstance := dag.NewDAG("production-workflow", "workflow-key", func(ctx context.Context, result mq.Result) { - log.Info("Workflow completed", - logger.Field{Key: "result", Value: string(result.Payload)}, - ) - }) - - // Initialize all enhanced components - setupEnhancedDAG(dagInstance, log) - - // Build the workflow - buildWorkflow(dagInstance, log) - - // Start the server and DAG - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - if err := server.Start(ctx); err != nil { - log.Error("Server failed to start", logger.Field{Key: "error", Value: err.Error()}) - } - }() - - // Wait for server to start - time.Sleep(100 * time.Millisecond) - - // Start enhanced DAG features - startEnhancedFeatures(ctx, dagInstance, log) - - // Set up HTTP API for monitoring and management - setupHTTPAPI(dagInstance, log) - - // Start the HTTP server - go func() { - log.Info("Starting HTTP server on :8080") - if err := http.ListenAndServe(":8080", nil); err != nil { - log.Error("HTTP server failed", logger.Field{Key: "error", Value: err.Error()}) - } - }() - - // Demonstrate the enhanced features - demonstrateFeatures(ctx, dagInstance, log) - - // Wait for shutdown signal - waitForShutdown(ctx, cancel, dagInstance, server, log) -} - -func setupEnhancedDAG(dagInstance *dag.DAG, log logger.Logger) { - // Initialize activity logger with memory persistence - activityConfig := dag.DefaultActivityLoggerConfig() - activityConfig.BufferSize = 500 - activityConfig.FlushInterval = 2 * time.Second - - 
persistence := dag.NewMemoryActivityPersistence() - dagInstance.InitializeActivityLogger(activityConfig, persistence) - - // Add custom activity hook - customHook := &CustomActivityHook{logger: log} - dagInstance.AddActivityHook(customHook) - - // Initialize monitoring with comprehensive configuration - monitorConfig := dag.MonitoringConfig{ - MetricsInterval: 5 * time.Second, - EnableHealthCheck: true, - BufferSize: 1000, - } - - alertThresholds := &dag.AlertThresholds{ - MaxFailureRate: 0.1, // 10% - MaxExecutionTime: 30 * time.Second, - MaxTasksInProgress: 100, - MinSuccessRate: 0.9, // 90% - MaxNodeFailures: 5, - HealthCheckInterval: 10 * time.Second, - } - - dagInstance.InitializeMonitoring(monitorConfig, alertThresholds) - - // Add custom alert handler - customAlertHandler := &CustomAlertHandler{logger: log} - dagInstance.AddAlertHandler(customAlertHandler) - - // Initialize configuration management - dagInstance.InitializeConfigManager() - - // Set up rate limiting - dagInstance.InitializeRateLimiter() - dagInstance.SetRateLimit("validate", 10.0, 5) // 10 req/sec, burst 5 - dagInstance.SetRateLimit("process", 20.0, 10) // 20 req/sec, burst 10 - dagInstance.SetRateLimit("finalize", 5.0, 2) // 5 req/sec, burst 2 - - // Initialize retry management - retryConfig := &dag.RetryConfig{ - MaxRetries: 3, - InitialDelay: 1 * time.Second, - MaxDelay: 10 * time.Second, - BackoffFactor: 2.0, - Jitter: true, - RetryCondition: func(err error) bool { - // Custom retry condition - retry on specific errors - return err != nil && err.Error() != "permanent_failure" - }, - } - dagInstance.InitializeRetryManager(retryConfig) - - // Initialize transaction management - txConfig := dag.TransactionConfig{ - DefaultTimeout: 5 * time.Minute, - CleanupInterval: 10 * time.Minute, - } - dagInstance.InitializeTransactionManager(txConfig) - - // Initialize cleanup management - cleanupConfig := dag.CleanupConfig{ - Interval: 5 * time.Minute, - TaskRetentionPeriod: 1 * time.Hour, - ResultRetentionPeriod: 2 * time.Hour, - MaxRetainedTasks: 1000, - } - dagInstance.InitializeCleanupManager(cleanupConfig) - - // Initialize performance optimizer - dagInstance.InitializePerformanceOptimizer() - - // Set up webhook manager for external notifications - httpClient := dag.NewSimpleHTTPClient(30 * time.Second) - webhookManager := dag.NewWebhookManager(httpClient, log) - - // Add webhook for task completion events - webhookConfig := dag.WebhookConfig{ - URL: "https://api.example.com/dag-events", // Replace with actual endpoint - Headers: map[string]string{"Authorization": "Bearer your-token"}, - RetryCount: 3, - Events: []string{"task_completed", "task_failed", "dag_completed"}, - } - webhookManager.AddWebhook("task_completed", webhookConfig) - dagInstance.SetWebhookManager(webhookManager) - - log.Info("Enhanced DAG features initialized successfully") -} - -func buildWorkflow(dagInstance *dag.DAG, log logger.Logger) { - // Create processors for each step - validator := NewExampleProcessor("validator") - processor := NewExampleProcessor("processor") - enricher := NewExampleProcessor("enricher") - finalizer := NewExampleProcessor("finalizer") - - // Build the workflow with retry configurations - retryConfig := &dag.RetryConfig{ - MaxRetries: 2, - InitialDelay: 500 * time.Millisecond, - MaxDelay: 5 * time.Second, - BackoffFactor: 2.0, - } - - dagInstance. - AddNodeWithRetry(dag.Function, "Validate Input", "validate", validator, retryConfig, true). - AddNodeWithRetry(dag.Function, "Process Data", "process", processor, retryConfig). 
- AddNodeWithRetry(dag.Function, "Enrich Data", "enrich", enricher, retryConfig). - AddNodeWithRetry(dag.Function, "Finalize", "finalize", finalizer, retryConfig). - Connect("validate", "process"). - Connect("process", "enrich"). - Connect("enrich", "finalize") - - // Add conditional connections - dagInstance.AddCondition("validate", "success", "process") - dagInstance.AddCondition("validate", "failure", "finalize") // Skip to finalize on validation failure - - // Validate the DAG structure - if err := dagInstance.ValidateDAG(); err != nil { - log.Error("DAG validation failed", logger.Field{Key: "error", Value: err.Error()}) - os.Exit(1) - } - - log.Info("Workflow built and validated successfully") -} - -func startEnhancedFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) { - // Start monitoring - dagInstance.StartMonitoring(ctx) - - // Start cleanup manager - dagInstance.StartCleanup(ctx) - - // Enable batch processing - dagInstance.SetBatchProcessingEnabled(true) - - log.Info("Enhanced features started") -} - -func setupHTTPAPI(dagInstance *dag.DAG, log logger.Logger) { - // Set up standard DAG handlers - dagInstance.Handlers(http.DefaultServeMux, "/dag") - - // Set up enhanced API endpoints - enhancedAPI := dag.NewEnhancedAPIHandler(dagInstance) - enhancedAPI.RegisterRoutes(http.DefaultServeMux) - - // Custom endpoints for demonstration - http.HandleFunc("/demo/activities", func(w http.ResponseWriter, r *http.Request) { - filter := dag.ActivityFilter{ - Limit: 50, - } - - activities, err := dagInstance.GetActivities(filter) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, activities); err != nil { - log.Error("Failed to write activities response", logger.Field{Key: "error", Value: err.Error()}) - } - }) - - http.HandleFunc("/demo/stats", func(w http.ResponseWriter, r *http.Request) { - stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{}) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, stats); err != nil { - log.Error("Failed to write stats response", logger.Field{Key: "error", Value: err.Error()}) - } - }) - - log.Info("HTTP API endpoints configured") -} - -func demonstrateFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) { - log.Info("Demonstrating enhanced DAG features...") - - // 1. Process a successful task - log.Info("Processing successful task...") - processTask(ctx, dagInstance, map[string]interface{}{ - "id": "task-001", - "data": "valid input data", - "type": "success", - }, log) - - // 2. Process a task that will fail - log.Info("Processing failing task...") - processTask(ctx, dagInstance, map[string]interface{}{ - "id": "task-002", - "data": nil, // This will cause processing issues - "type": "failure", - }, log) - - // 3. Process with transaction - log.Info("Processing with transaction...") - processWithTransaction(ctx, dagInstance, map[string]interface{}{ - "id": "task-003", - "data": "transaction data", - "type": "transaction", - }, log) - - // 4. Demonstrate rate limiting - log.Info("Demonstrating rate limiting...") - demonstrateRateLimiting(ctx, dagInstance, log) - - // 5. 
Show monitoring metrics - time.Sleep(2 * time.Second) // Allow time for metrics to accumulate - showMetrics(dagInstance, log) - - // 6. Show activity logs - showActivityLogs(dagInstance, log) -} - -func processTask(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) { - // Add context information - ctx = context.WithValue(ctx, "user_id", "demo-user") - ctx = context.WithValue(ctx, "session_id", "demo-session") - ctx = context.WithValue(ctx, "trace_id", mq.NewID()) - - result := dagInstance.Process(ctx, payload) - if result.Error != nil { - log.Error("Task processing failed", - logger.Field{Key: "error", Value: result.Error.Error()}, - logger.Field{Key: "payload", Value: payload}, - ) - } else { - log.Info("Task processed successfully", - logger.Field{Key: "result_size", Value: len(result.Payload)}, - ) - } -} - -func processWithTransaction(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) { - taskID := fmt.Sprintf("tx-%s", mq.NewID()) - - // Begin transaction - tx := dagInstance.BeginTransaction(taskID) - if tx == nil { - log.Error("Failed to begin transaction") - return - } - - // Add transaction context - ctx = context.WithValue(ctx, "transaction_id", tx.ID) - ctx = context.WithValue(ctx, "task_id", taskID) - - // Process the task - result := dagInstance.Process(ctx, payload) - - // Commit or rollback based on result - if result.Error != nil { - if err := dagInstance.RollbackTransaction(tx.ID); err != nil { - log.Error("Failed to rollback transaction", - logger.Field{Key: "tx_id", Value: tx.ID}, - logger.Field{Key: "error", Value: err.Error()}, - ) - } else { - log.Info("Transaction rolled back", - logger.Field{Key: "tx_id", Value: tx.ID}, - ) - } - } else { - if err := dagInstance.CommitTransaction(tx.ID); err != nil { - log.Error("Failed to commit transaction", - logger.Field{Key: "tx_id", Value: tx.ID}, - logger.Field{Key: "error", Value: err.Error()}, - ) - } else { - log.Info("Transaction committed", - logger.Field{Key: "tx_id", Value: tx.ID}, - ) - } - } -} - -func demonstrateRateLimiting(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) { - // Try to exceed rate limits - for i := 0; i < 15; i++ { - allowed := dagInstance.CheckRateLimit("validate") - log.Info("Rate limit check", - logger.Field{Key: "attempt", Value: i + 1}, - logger.Field{Key: "allowed", Value: allowed}, - ) - - if allowed { - processTask(ctx, dagInstance, map[string]interface{}{ - "id": fmt.Sprintf("rate-test-%d", i), - "data": "rate limiting test", - }, log) - } - - time.Sleep(100 * time.Millisecond) - } -} - -func showMetrics(dagInstance *dag.DAG, log logger.Logger) { - metrics := dagInstance.GetMonitoringMetrics() - if metrics != nil { - log.Info("Current DAG Metrics", - logger.Field{Key: "total_tasks", Value: metrics.TasksTotal}, - logger.Field{Key: "completed_tasks", Value: metrics.TasksCompleted}, - logger.Field{Key: "failed_tasks", Value: metrics.TasksFailed}, - logger.Field{Key: "tasks_in_progress", Value: metrics.TasksInProgress}, - logger.Field{Key: "avg_execution_time", Value: metrics.AverageExecutionTime.String()}, - ) - - // Show node-specific metrics - for nodeID := range map[string]bool{"validate": true, "process": true, "enrich": true, "finalize": true} { - if nodeStats := dagInstance.GetNodeStats(nodeID); nodeStats != nil { - log.Info("Node Metrics", - logger.Field{Key: "node_id", Value: nodeID}, - logger.Field{Key: "executions", Value: nodeStats.TotalExecutions}, - logger.Field{Key: "failures", 
Value: nodeStats.FailureCount}, - logger.Field{Key: "avg_duration", Value: nodeStats.AverageExecutionTime.String()}, - ) - } - } - } else { - log.Warn("Monitoring metrics not available") - } -} - -func showActivityLogs(dagInstance *dag.DAG, log logger.Logger) { - // Get recent activities - filter := dag.ActivityFilter{ - Limit: 10, - SortBy: "timestamp", - SortOrder: "desc", - } - - activities, err := dagInstance.GetActivities(filter) - if err != nil { - log.Error("Failed to get activities", logger.Field{Key: "error", Value: err.Error()}) - return - } - - log.Info("Recent Activities", logger.Field{Key: "count", Value: len(activities)}) - for _, activity := range activities { - log.Info("Activity", - logger.Field{Key: "id", Value: activity.ID}, - logger.Field{Key: "type", Value: string(activity.Type)}, - logger.Field{Key: "level", Value: string(activity.Level)}, - logger.Field{Key: "message", Value: activity.Message}, - logger.Field{Key: "task_id", Value: activity.TaskID}, - logger.Field{Key: "node_id", Value: activity.NodeID}, - ) - } - - // Get activity statistics - stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{}) - if err != nil { - log.Error("Failed to get activity stats", logger.Field{Key: "error", Value: err.Error()}) - return - } - - log.Info("Activity Statistics", - logger.Field{Key: "total_activities", Value: stats.TotalActivities}, - logger.Field{Key: "success_rate", Value: fmt.Sprintf("%.2f%%", stats.SuccessRate*100)}, - logger.Field{Key: "failure_rate", Value: fmt.Sprintf("%.2f%%", stats.FailureRate*100)}, - logger.Field{Key: "avg_duration", Value: stats.AverageDuration.String()}, - ) -} - -func waitForShutdown(ctx context.Context, cancel context.CancelFunc, dagInstance *dag.DAG, server *mq.Server, log logger.Logger) { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - log.Info("DAG system is running. 
Available endpoints:", - logger.Field{Key: "workflow", Value: "http://localhost:8080/dag/"}, - logger.Field{Key: "process", Value: "http://localhost:8080/dag/process"}, - logger.Field{Key: "metrics", Value: "http://localhost:8080/api/dag/metrics"}, - logger.Field{Key: "health", Value: "http://localhost:8080/api/dag/health"}, - logger.Field{Key: "activities", Value: "http://localhost:8080/demo/activities"}, - logger.Field{Key: "stats", Value: "http://localhost:8080/demo/stats"}, - ) - - <-sigChan - log.Info("Shutdown signal received, cleaning up...") - - // Graceful shutdown - cancel() - - // Stop enhanced features - dagInstance.StopEnhanced(ctx) - - // Stop server - if err := server.Stop(ctx); err != nil { - log.Error("Error stopping server", logger.Field{Key: "error", Value: err.Error()}) - } - - log.Info("Shutdown complete") -} diff --git a/examples/publisher.go b/examples/publisher.go index 04052d0..f823f96 100644 --- a/examples/publisher.go +++ b/examples/publisher.go @@ -13,7 +13,7 @@ func main() { Payload: payload, } publisher := mq.NewPublisher("publish-1", mq.WithBrokerURL(":8081")) - for i := 0; i < 10000000; i++ { + for i := 0; i < 2; i++ { // publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) err := publisher.Publish(context.Background(), task, "queue1") if err != nil { diff --git a/examples/sms_form_dag.go b/examples/sms_form_dag.go index 888c8be..316f164 100644 --- a/examples/sms_form_dag.go +++ b/examples/sms_form_dag.go @@ -10,6 +10,7 @@ import ( "github.com/oarkflow/json" "github.com/oarkflow/mq/dag" + "github.com/oarkflow/mq/utils" "github.com/oarkflow/jet" @@ -19,7 +20,7 @@ import ( func main() { flow := dag.NewDAG("SMS Sender", "sms-sender", func(taskID string, result mq.Result) { - fmt.Printf("SMS workflow completed for task %s: %s\n", taskID, string(result.Payload)) + fmt.Printf("SMS workflow completed for task %s: %s\n", taskID, string(utils.RemoveRecursiveFromJSON(result.Payload, "html_content"))) }) // Add SMS workflow nodes diff --git a/go.mod b/go.mod index ab50191..79292fd 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,8 @@ require ( github.com/oarkflow/log v1.0.79 github.com/oarkflow/xid v1.2.5 github.com/prometheus/client_golang v1.21.1 + github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.33.0 golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 golang.org/x/time v0.11.0 ) @@ -23,13 +25,16 @@ require ( github.com/andybalholm/brotli v1.1.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/goccy/go-reflect v1.2.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.63.0 // indirect github.com/prometheus/procfs v0.16.0 // indirect @@ -38,4 +43,5 @@ require ( github.com/valyala/fasthttp v1.59.0 // indirect golang.org/x/sys v0.31.0 // indirect google.golang.org/protobuf v1.36.6 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4830dda..65a8982 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/beorn7/perks 
v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/goccy/go-reflect v1.2.0 h1:O0T8rZCuNmGXewnATuKYnkL0xm6o8UNOJZd/gOkb9ms= @@ -18,6 +19,10 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -59,6 +64,8 @@ github.com/prometheus/procfs v0.16.0/go.mod h1:8veyXUu3nGP7oaCxhX6yeaM5u4stL2FeM github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -67,6 +74,8 @@ github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDp github.com/valyala/fasthttp v1.59.0/go.mod h1:GTxNb9Bc6r2a9D0TWNSPwDz78UxnTGBViY3xZNEqyYU= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -76,5 +85,8 @@ golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod 
h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mq.go b/mq.go index 3126617..25d2231 100644 --- a/mq.go +++ b/mq.go @@ -688,7 +688,10 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error { if err != nil { return err } - msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap()) + msg, err := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap()) + if err != nil { + return fmt.Errorf("failed to create PUBLISH message: %w", err) + } b.broadcastToConsumers(msg) return nil } @@ -698,7 +701,11 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M taskID, _ := jsonparser.GetString(msg.Payload, "id") log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID) - ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) + ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) + if err != nil { + log.Printf("Error creating PUBLISH_ACK message: %v\n", err) + return + } if err := b.send(ctx, conn, ack); err != nil { log.Printf("Error sending PUBLISH_ACK: %v\n", err) } @@ -713,7 +720,11 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { consumerID := b.AddConsumer(ctx, msg.Queue, conn) - ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) + ack, err := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) + if err != nil { + log.Printf("Error creating SUBSCRIBE_ACK message: %v\n", err) + return + } if err := b.send(ctx, conn, ack); err != nil { log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) } @@ -894,8 +905,12 @@ func (b *Broker) handleConsumer( fn := func(queue *Queue) { con, ok := queue.consumers.Get(consumerID) if ok { - ack := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID}) - err := b.send(ctx, con.conn, ack) + ack, err := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID}) + if err != nil { + log.Printf("Error creating message for consumer %s: %v", consumerID, err) + return + } + err = b.send(ctx, con.conn, ack) if err == nil { con.state = state } @@ -921,7 +936,11 @@ func (b *Broker) UpdateConsumer(ctx context.Context, consumerID string, config D fn := func(queue *Queue) error { con, ok := queue.consumers.Get(consumerID) if ok { - ack := codec.NewMessage(consts.CONSUMER_UPDATE, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID}) + ack, err := codec.NewMessage(consts.CONSUMER_UPDATE, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID}) + if err != nil { + log.Printf("Error creating message for consumer %s: %v", consumerID, err) + return err + } return b.send(ctx, con.conn, ack) } return nil diff --git a/pool.go b/pool.go index a8e8265..7a3a023 100644 --- a/pool.go +++ b/pool.go @@ -478,6 
+478,7 @@ func (wp *Pool) processNextBatch() { if len(tasks) > 0 { for _, task := range tasks { if task != nil && !wp.gracefulShutdown { + wp.taskCompletionNotifier.Add(1) wp.handleTask(task) } } @@ -487,6 +488,8 @@ func (wp *Pool) processNextBatch() { func (wp *Pool) handleTask(task *QueueTask) { if task == nil || task.payload == nil { wp.logger.Warn().Msg("Received nil task or payload") + // Only call Done if Add was called (which is now only for actual tasks) + wp.taskCompletionNotifier.Done() return } @@ -803,9 +806,6 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er wp.taskAvailableCond.Signal() wp.taskAvailableCond.L.Unlock() - // Track pending task - wp.taskCompletionNotifier.Add(1) - // Update metrics atomic.AddInt64(&wp.metrics.TotalScheduled, 1) diff --git a/publisher.go b/publisher.go index 6f8dc6c..e2c82ec 100644 --- a/publisher.go +++ b/publisher.go @@ -7,11 +7,11 @@ import ( "net" "sync" "time" - + "github.com/oarkflow/json" - + "github.com/oarkflow/json/jsonparser" - + "github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/utils" @@ -111,11 +111,14 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net. if err != nil { return err } - msg := codec.NewMessage(command, payload, queue, headers) + msg, err := codec.NewMessage(command, payload, queue, headers) + if err != nil { + return err + } if err := codec.SendMessage(ctx, conn, msg); err != nil { return err } - + return p.waitForAck(ctx, conn) } diff --git a/utils/json.go b/utils/json.go new file mode 100644 index 0000000..cbb25cd --- /dev/null +++ b/utils/json.go @@ -0,0 +1,45 @@ +package utils + +import ( + "github.com/oarkflow/json" + "github.com/oarkflow/json/jsonparser" +) + +func RemoveFromJSONBye(jsonStr json.RawMessage, key ...string) json.RawMessage { + return jsonparser.Delete(jsonStr, key...) +} + +func RemoveRecursiveFromJSON(jsonStr json.RawMessage, key ...string) json.RawMessage { + var data any + if err := json.Unmarshal(jsonStr, &data); err != nil { + return jsonStr + } + + for _, k := range key { + data = removeKeyRecursive(data, k) + } + + result, err := json.Marshal(data) + if err != nil { + return jsonStr + } + return result +} + +func removeKeyRecursive(data any, key string) any { + switch v := data.(type) { + case map[string]any: + delete(v, key) + for k, val := range v { + v[k] = removeKeyRecursive(val, key) + } + return v + case []any: + for i, item := range v { + v[i] = removeKeyRecursive(item, key) + } + return v + default: + return v + } +}
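
Usage sketch (illustrative, not part of the changeset above): the new codec test helpers added in codec/testing.go could be driven from a regular Go test roughly as follows. The command value, queue name, payload size, and test name are assumptions for the example; NewMessage is used with the (message, error) signature introduced in this diff.

package codec

import (
	"bytes"
	"testing"

	"github.com/oarkflow/mq/consts"
)

// TestCodecRoundTrip exercises SendReceiveTest and FragmentationTest
// against the in-memory MockConn provided by the test suite.
func TestCodecRoundTrip(t *testing.T) {
	ts := NewCodecTestSuite()

	// Build a small message using the (msg, err) constructor from this changeset.
	msg, err := NewMessage(consts.CMD(1), []byte(`{"id":"task-1"}`), "test_queue",
		map[string]string{"content-type": "application/json"})
	if err != nil {
		t.Fatalf("NewMessage failed: %v", err)
	}

	if err := ts.SendReceiveTest(msg); err != nil {
		t.Fatalf("send/receive round trip failed: %v", err)
	}

	// A 1 MiB payload is assumed here to exceed the fragmentation threshold
	// used by DefaultConfig; adjust if the actual threshold differs.
	if err := ts.FragmentationTest(bytes.Repeat([]byte("x"), 1<<20)); err != nil {
		t.Fatalf("fragmentation round trip failed: %v", err)
	}
}

Similarly, a minimal sketch of the new utils.RemoveRecursiveFromJSON helper, which the updated DAG examples use to strip bulky fields such as "html_content" from logged payloads; the sample payload is made up for illustration.

package main

import (
	"fmt"

	"github.com/oarkflow/json"
	"github.com/oarkflow/mq/utils"
)

func main() {
	payload := json.RawMessage(`{"status":"sent","html_content":"<h1>Hi</h1>","steps":[{"html_content":"<p>x</p>","ok":true}]}`)
	// Removes the key at every nesting level, in both objects and arrays.
	cleaned := utils.RemoveRecursiveFromJSON(payload, "html_content")
	fmt.Println(string(cleaned))
}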