mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-26 20:11:16 +08:00
feat: use "GetTags" and "SetTags"
This commit is contained in:
@@ -1,343 +0,0 @@
|
||||
// apperror/apperror.go
|
||||
package apperror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// APP_ENV values
|
||||
const (
|
||||
EnvDevelopment = "development"
|
||||
EnvStaging = "staging"
|
||||
EnvProduction = "production"
|
||||
)
|
||||
|
||||
// AppError defines a structured application error
|
||||
type AppError struct {
|
||||
Code string `json:"code"` // 9-digit code: XXX|AA|DD|YY
|
||||
Message string `json:"message"` // human-readable message
|
||||
StatusCode int `json:"-"` // HTTP status, not serialized
|
||||
Err error `json:"-"` // wrapped error, not serialized
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // optional extra info
|
||||
StackTrace []string `json:"stackTrace,omitempty"`
|
||||
}
|
||||
|
||||
// Error implements error interface
|
||||
func (e *AppError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("[%s] %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap enables errors.Is / errors.As
|
||||
func (e *AppError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// WithMetadata returns a shallow copy with added metadata key/value
|
||||
func (e *AppError) WithMetadata(key string, val any) *AppError {
|
||||
newMD := make(map[string]any, len(e.Metadata)+1)
|
||||
for k, v := range e.Metadata {
|
||||
newMD[k] = v
|
||||
}
|
||||
newMD[key] = val
|
||||
|
||||
return &AppError{
|
||||
Code: e.Code,
|
||||
Message: e.Message,
|
||||
StatusCode: e.StatusCode,
|
||||
Err: e.Err,
|
||||
Metadata: newMD,
|
||||
StackTrace: e.StackTrace,
|
||||
}
|
||||
}
|
||||
|
||||
// GetStackTraceArray returns the error stack trace as an array of strings
|
||||
func (e *AppError) GetStackTraceArray() []string {
|
||||
return e.StackTrace
|
||||
}
|
||||
|
||||
// GetStackTraceString returns the error stack trace as a single string
|
||||
func (e *AppError) GetStackTraceString() string {
|
||||
return strings.Join(e.StackTrace, "\n")
|
||||
}
|
||||
|
||||
// captureStackTrace returns a slice of strings representing the stack trace.
|
||||
func captureStackTrace() []string {
|
||||
const depth = 32
|
||||
var pcs [depth]uintptr
|
||||
n := runtime.Callers(3, pcs[:])
|
||||
frames := runtime.CallersFrames(pcs[:n])
|
||||
isDebug := os.Getenv("APP_DEBUG") == "true"
|
||||
var stack []string
|
||||
for {
|
||||
frame, more := frames.Next()
|
||||
var file string
|
||||
if !isDebug {
|
||||
file = "/" + filepath.Base(frame.File)
|
||||
} else {
|
||||
file = frame.File
|
||||
}
|
||||
if strings.HasSuffix(file, ".go") {
|
||||
file = strings.TrimSuffix(file, ".go") + ".sec"
|
||||
}
|
||||
stack = append(stack, fmt.Sprintf("%s:%d %s", file, frame.Line, frame.Function))
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
return stack
|
||||
}
|
||||
|
||||
// buildCode constructs a 9-digit code: XXX|AA|DD|YY
|
||||
func buildCode(httpCode, appCode, domainCode, errCode int) string {
|
||||
return fmt.Sprintf("%03d%02d%02d%02d", httpCode, appCode, domainCode, errCode)
|
||||
}
|
||||
|
||||
// New creates a fresh AppError
|
||||
func New(httpCode, appCode, domainCode, errCode int, msg string) *AppError {
|
||||
return &AppError{
|
||||
Code: buildCode(httpCode, appCode, domainCode, errCode),
|
||||
Message: msg,
|
||||
StatusCode: httpCode,
|
||||
// Prototype: no StackTrace captured at registration time.
|
||||
}
|
||||
}
|
||||
|
||||
// Modify Wrap to always capture a fresh stack trace.
|
||||
func Wrap(err error, httpCode, appCode, domainCode, errCode int, msg string) *AppError {
|
||||
return &AppError{
|
||||
Code: buildCode(httpCode, appCode, domainCode, errCode),
|
||||
Message: msg,
|
||||
StatusCode: httpCode,
|
||||
Err: err,
|
||||
StackTrace: captureStackTrace(),
|
||||
}
|
||||
}
|
||||
|
||||
// New helper: Instance attaches the runtime stack trace to a prototype error.
|
||||
func Instance(e *AppError) *AppError {
|
||||
// Create a shallow copy and attach the current stack trace.
|
||||
copyE := *e
|
||||
copyE.StackTrace = captureStackTrace()
|
||||
return ©E
|
||||
}
|
||||
|
||||
// Modify toAppError to instance a prototype if it lacks a stack trace.
|
||||
func toAppError(err error) *AppError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var ae *AppError
|
||||
if errors.As(err, &ae) {
|
||||
if len(ae.StackTrace) == 0 { // Prototype without context.
|
||||
return Instance(ae)
|
||||
}
|
||||
return ae
|
||||
}
|
||||
// fallback to internal error 500|00|00|00 with fresh stack trace.
|
||||
return Wrap(err, http.StatusInternalServerError, 0, 0, 0, "Internal server error")
|
||||
}
|
||||
|
||||
// onError, if set, is called before writing any JSON error
|
||||
var onError func(*AppError)
|
||||
|
||||
func OnError(hook func(*AppError)) {
|
||||
onError = hook
|
||||
}
|
||||
|
||||
// WriteJSONError writes an error as JSON, includes X-Request-ID, hides details in production
|
||||
func WriteJSONError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
appErr := toAppError(err)
|
||||
|
||||
// attach request ID
|
||||
if rid := r.Header.Get("X-Request-ID"); rid != "" {
|
||||
appErr = appErr.WithMetadata("request_id", rid)
|
||||
}
|
||||
// hook
|
||||
if onError != nil {
|
||||
onError(appErr)
|
||||
}
|
||||
// If no stack trace is present, capture current context stack trace.
|
||||
if os.Getenv("APP_ENV") != EnvProduction {
|
||||
appErr.StackTrace = captureStackTrace()
|
||||
}
|
||||
|
||||
fmt.Println(appErr.StackTrace)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(appErr.StatusCode)
|
||||
|
||||
resp := map[string]any{
|
||||
"code": appErr.Code,
|
||||
"message": appErr.Message,
|
||||
}
|
||||
if len(appErr.Metadata) > 0 {
|
||||
resp["metadata"] = appErr.Metadata
|
||||
}
|
||||
if os.Getenv("APP_ENV") != EnvProduction {
|
||||
resp["stack"] = appErr.StackTrace
|
||||
}
|
||||
if appErr.Err != nil {
|
||||
resp["details"] = appErr.Err.Error()
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
type ErrorRegistry struct {
|
||||
registry map[string]*AppError
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (er *ErrorRegistry) Get(name string) (*AppError, bool) {
|
||||
er.mu.RLock()
|
||||
defer er.mu.RUnlock()
|
||||
e, ok := er.registry[name]
|
||||
return e, ok
|
||||
}
|
||||
|
||||
func (er *ErrorRegistry) Set(name string, e *AppError) {
|
||||
er.mu.Lock()
|
||||
defer er.mu.Unlock()
|
||||
er.registry[name] = e
|
||||
}
|
||||
|
||||
func (er *ErrorRegistry) Delete(name string) {
|
||||
er.mu.Lock()
|
||||
defer er.mu.Unlock()
|
||||
delete(er.registry, name)
|
||||
}
|
||||
|
||||
func (er *ErrorRegistry) List() []*AppError {
|
||||
er.mu.RLock()
|
||||
defer er.mu.RUnlock()
|
||||
out := make([]*AppError, 0, len(er.registry))
|
||||
for _, e := range er.registry {
|
||||
// create a shallow copy and remove the StackTrace for listing
|
||||
copyE := *e
|
||||
copyE.StackTrace = nil
|
||||
out = append(out, ©E)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (er *ErrorRegistry) GetByCode(code string) (*AppError, bool) {
|
||||
er.mu.RLock()
|
||||
defer er.mu.RUnlock()
|
||||
for _, e := range er.registry {
|
||||
if e.Code == code {
|
||||
return e, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var (
|
||||
registry *ErrorRegistry
|
||||
)
|
||||
|
||||
// Register adds a named error; fails if name exists
|
||||
func Register(name string, e *AppError) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("error name cannot be empty")
|
||||
}
|
||||
registry.Set(name, e)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update replaces an existing named error; fails if not found
|
||||
func Update(name string, e *AppError) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("error name cannot be empty")
|
||||
}
|
||||
registry.Set(name, e)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unregister removes a named error
|
||||
func Unregister(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("error name cannot be empty")
|
||||
}
|
||||
registry.Delete(name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a named error
|
||||
func Get(name string) (*AppError, bool) {
|
||||
return registry.Get(name)
|
||||
}
|
||||
|
||||
// GetByCode retrieves an error by its 9-digit code
|
||||
func GetByCode(code string) (*AppError, bool) {
|
||||
if code == "" {
|
||||
return nil, false
|
||||
}
|
||||
return registry.GetByCode(code)
|
||||
}
|
||||
|
||||
// List returns all registered errors
|
||||
func List() []*AppError {
|
||||
return registry.List()
|
||||
}
|
||||
|
||||
// Is/As shortcuts updated to check all registered errors
|
||||
func Is(err, target error) bool {
|
||||
if errors.Is(err, target) {
|
||||
return true
|
||||
}
|
||||
registry.mu.RLock()
|
||||
defer registry.mu.RUnlock()
|
||||
for _, e := range registry.registry {
|
||||
if errors.Is(err, e) || errors.Is(e, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func As(err error, target any) bool {
|
||||
if errors.As(err, target) {
|
||||
return true
|
||||
}
|
||||
registry.mu.RLock()
|
||||
defer registry.mu.RUnlock()
|
||||
for _, e := range registry.registry {
|
||||
if errors.As(err, target) || errors.As(e, target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HTTPMiddleware catches panics and converts to JSON 500
|
||||
func HTTPMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
p := fmt.Errorf("panic: %v", rec)
|
||||
WriteJSONError(w, r, Wrap(p, http.StatusInternalServerError, 0, 0, 0, "Internal server error"))
|
||||
}
|
||||
}()
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// preload some common errors (with 2-digit app/domain codes)
|
||||
func init() {
|
||||
registry = &ErrorRegistry{registry: make(map[string]*AppError)}
|
||||
_ = Register("ErrNotFound", New(http.StatusNotFound, 1, 1, 1, "Resource not found")) // → "404010101"
|
||||
_ = Register("ErrInvalidInput", New(http.StatusBadRequest, 1, 1, 2, "Invalid input provided")) // → "400010102"
|
||||
_ = Register("ErrInternal", New(http.StatusInternalServerError, 1, 1, 0, "Internal server error")) // → "500010100"
|
||||
_ = Register("ErrUnauthorized", New(http.StatusUnauthorized, 1, 1, 3, "Unauthorized")) // → "401010103"
|
||||
_ = Register("ErrForbidden", New(http.StatusForbidden, 1, 1, 4, "Forbidden")) // → "403010104"
|
||||
}
|
181
codec/README.md
Normal file
181
codec/README.md
Normal file
@@ -0,0 +1,181 @@
|
||||
# Message Queue Codec
|
||||
|
||||
This package provides a robust, production-ready codec implementation for serializing, transmitting, and deserializing messages in a distributed messaging system.
|
||||
|
||||
## Features
|
||||
|
||||
- **Message Validation**: Comprehensive validation of message format and content
|
||||
- **Efficient Serialization**: Pluggable serialization with JSON as default
|
||||
- **Compression**: Optional payload compression for large messages
|
||||
- **Encryption**: Optional message encryption for sensitive data
|
||||
- **Large Message Support**: Automatic fragmentation and reassembly of large messages
|
||||
- **Connection Health**: Heartbeat mechanism for connection monitoring
|
||||
- **Performance Optimized**: Buffer pooling, efficient memory usage
|
||||
- **Robust Error Handling**: Detailed error types and error wrapping
|
||||
- **Timeout Management**: Context-aware deadline handling
|
||||
- **Observability**: Built-in statistics tracking
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Message Sending/Receiving
|
||||
|
||||
```go
|
||||
// Create a message
|
||||
msg, err := codec.NewMessage(consts.CmdPublish, payload, "my-queue", headers)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create message: %v", err)
|
||||
}
|
||||
|
||||
// Send the message
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
codec := codec.NewCodec(codec.DefaultConfig())
|
||||
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||||
log.Fatalf("Failed to send message: %v", err)
|
||||
}
|
||||
|
||||
// Receive a message
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
receivedMsg, err := codec.ReadMessage(ctx, conn)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to receive message: %v", err)
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Serialization
|
||||
|
||||
```go
|
||||
// Set a custom marshaller/unmarshaller
|
||||
codec.SetMarshaller(func(v any) ([]byte, error) {
|
||||
// Custom serialization logic
|
||||
return someCustomSerializer.Marshal(v)
|
||||
})
|
||||
|
||||
codec.SetUnmarshaller(func(data []byte, v any) error {
|
||||
// Custom deserialization logic
|
||||
return someCustomSerializer.Unmarshal(data, v)
|
||||
})
|
||||
```
|
||||
|
||||
### Enabling Compression
|
||||
|
||||
```go
|
||||
// Enable compression globally
|
||||
codec.EnableCompression(true)
|
||||
|
||||
// Or configure it per codec instance
|
||||
config := codec.DefaultConfig()
|
||||
config.EnableCompression = true
|
||||
codec := codec.NewCodec(config)
|
||||
```
|
||||
|
||||
### Enabling Encryption
|
||||
|
||||
```go
|
||||
// Generate a secure key
|
||||
key := make([]byte, 32) // 256-bit key
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
log.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
// Enable encryption globally
|
||||
codec.EnableEncryption(true, key)
|
||||
|
||||
// Or configure it per serialization manager
|
||||
config := codec.DefaultSerializationConfig()
|
||||
config.EnableEncryption = true
|
||||
config.EncryptionKey = key
|
||||
config.PreferredCipher = "chacha20poly1305" // or "aes-gcm"
|
||||
```
|
||||
|
||||
### Connection Health Monitoring
|
||||
|
||||
```go
|
||||
// Create a heartbeat manager
|
||||
codec := codec.NewCodec(codec.DefaultConfig())
|
||||
hm := codec.NewHeartbeatManager(codec, conn)
|
||||
|
||||
// Configure heartbeat
|
||||
hm.SetInterval(15 * time.Second)
|
||||
hm.SetTimeout(45 * time.Second)
|
||||
hm.SetOnFailure(func(err error) {
|
||||
log.Printf("Heartbeat failure: %v", err)
|
||||
// Take action like closing connection
|
||||
})
|
||||
|
||||
// Start heartbeat monitoring
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The codec behavior can be customized through the `Config` struct:
|
||||
|
||||
```go
|
||||
config := &codec.Config{
|
||||
MaxMessageSize: 32 * 1024 * 1024, // 32MB max message size
|
||||
MaxHeaderSize: 64 * 1024, // 64KB max header size
|
||||
MaxQueueLength: 128, // Max queue name length
|
||||
ReadTimeout: 15 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
EnableCompression: true,
|
||||
BufferPoolSize: 2000,
|
||||
}
|
||||
|
||||
codec := codec.NewCodec(config)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The codec provides detailed error types for different failure scenarios:
|
||||
|
||||
- `ErrMessageTooLarge`: Message exceeds maximum size
|
||||
- `ErrInvalidMessage`: Invalid message format
|
||||
- `ErrInvalidQueue`: Invalid queue name
|
||||
- `ErrInvalidCommand`: Invalid command
|
||||
- `ErrConnectionClosed`: Connection closed
|
||||
- `ErrTimeout`: Operation timeout
|
||||
- `ErrProtocolMismatch`: Protocol version mismatch
|
||||
- `ErrFragmentationRequired`: Message requires fragmentation
|
||||
- `ErrInvalidFragment`: Invalid message fragment
|
||||
- `ErrFragmentTimeout`: Timed out waiting for fragments
|
||||
- `ErrFragmentMissing`: Missing fragments in sequence
|
||||
|
||||
Error handling example:
|
||||
|
||||
```go
|
||||
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||||
if errors.Is(err, codec.ErrMessageTooLarge) {
|
||||
// Handle message size error
|
||||
} else if errors.Is(err, codec.ErrTimeout) {
|
||||
// Handle timeout error
|
||||
} else {
|
||||
// Handle other errors
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The codec package includes testing utilities for validation and performance testing:
|
||||
|
||||
```go
|
||||
ts := codec.NewCodecTestSuite()
|
||||
|
||||
// Test basic message sending/receiving
|
||||
msg, _ := codec.NewMessage(consts.CmdPublish, []byte("test"), "test-queue", nil)
|
||||
if err := ts.SendReceiveTest(msg); err != nil {
|
||||
log.Fatalf("Test failed: %v", err)
|
||||
}
|
||||
|
||||
// Test fragmentation/reassembly
|
||||
largePayload := make([]byte, 20*1024*1024) // 20MB payload
|
||||
rand.Read(largePayload) // Fill with random data
|
||||
if err := ts.FragmentationTest(largePayload); err != nil {
|
||||
log.Fatalf("Fragmentation test failed: %v", err)
|
||||
}
|
||||
```
|
465
codec/codec.go
465
codec/codec.go
@@ -4,103 +4,492 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io" // added for full reads
|
||||
"net"
|
||||
"sync"
|
||||
"time" // added for handling deadlines
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/internal/bpool"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Headers map[string]string `msgpack:"h"`
|
||||
Queue string `msgpack:"q"`
|
||||
Payload []byte `msgpack:"p"`
|
||||
Command consts.CMD `msgpack:"c"`
|
||||
// Protocol version for backward compatibility
|
||||
const (
|
||||
ProtocolVersion = uint8(1)
|
||||
MaxMessageSize = 64 * 1024 * 1024 // 64MB default limit
|
||||
MaxHeaderSize = 1024 * 1024 // 1MB header limit
|
||||
MaxQueueLength = 255 // Max queue name length
|
||||
FragmentationThreshold = 16 * 1024 * 1024 // Messages larger than 16MB will be fragmented
|
||||
FragmentSize = 8 * 1024 * 1024 // 8MB fragment size
|
||||
MaxFragments = 256 // Maximum fragments per message
|
||||
)
|
||||
|
||||
// Error definitions
|
||||
var (
|
||||
ErrMessageTooLarge = errors.New("message exceeds maximum size")
|
||||
ErrInvalidMessage = errors.New("invalid message format")
|
||||
ErrInvalidQueue = errors.New("invalid queue name")
|
||||
ErrInvalidCommand = errors.New("invalid command")
|
||||
ErrConnectionClosed = errors.New("connection closed")
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
ErrProtocolMismatch = errors.New("protocol version mismatch")
|
||||
ErrFragmentationRequired = errors.New("message requires fragmentation")
|
||||
ErrInvalidFragment = errors.New("invalid message fragment")
|
||||
ErrFragmentTimeout = errors.New("timed out waiting for fragments")
|
||||
ErrFragmentMissing = errors.New("missing fragments in sequence")
|
||||
)
|
||||
|
||||
// Config holds codec configuration
|
||||
type Config struct {
|
||||
MaxMessageSize uint32
|
||||
MaxHeaderSize uint32
|
||||
MaxQueueLength uint8
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
EnableCompression bool
|
||||
BufferPoolSize int
|
||||
}
|
||||
|
||||
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
|
||||
// DefaultConfig returns default configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
MaxMessageSize: MaxMessageSize,
|
||||
MaxHeaderSize: MaxHeaderSize,
|
||||
MaxQueueLength: MaxQueueLength,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
EnableCompression: false,
|
||||
BufferPoolSize: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
// MessageType indicates the type of message being sent
|
||||
type MessageType uint8
|
||||
|
||||
const (
|
||||
MessageTypeStandard MessageType = iota
|
||||
MessageTypeFragment
|
||||
MessageTypeHeartbeat
|
||||
MessageTypeAck
|
||||
MessageTypeError
|
||||
)
|
||||
|
||||
// MessageFlag represents various flags that can be set on messages
|
||||
type MessageFlag uint16
|
||||
|
||||
const (
|
||||
FlagNone MessageFlag = 0
|
||||
FlagFragmented MessageFlag = 1 << iota
|
||||
FlagCompressed
|
||||
FlagEncrypted
|
||||
FlagHighPriority
|
||||
FlagRedelivered
|
||||
FlagNoAck
|
||||
)
|
||||
|
||||
// Message represents a protocol message with validation
|
||||
type Message struct {
|
||||
Headers map[string]string `msgpack:"h" json:"headers"`
|
||||
Queue string `msgpack:"q" json:"queue"`
|
||||
Payload []byte `msgpack:"p" json:"payload"`
|
||||
Command consts.CMD `msgpack:"c" json:"command"`
|
||||
Version uint8 `msgpack:"v" json:"version"`
|
||||
Timestamp int64 `msgpack:"t" json:"timestamp"`
|
||||
ID string `msgpack:"i" json:"id,omitempty"`
|
||||
Flags MessageFlag `msgpack:"f" json:"flags"`
|
||||
Type MessageType `msgpack:"mt" json:"messageType"`
|
||||
FragmentID uint32 `msgpack:"fid" json:"fragmentId,omitempty"`
|
||||
Fragments uint16 `msgpack:"fs" json:"fragments,omitempty"`
|
||||
Sequence uint16 `msgpack:"seq" json:"sequence,omitempty"`
|
||||
}
|
||||
|
||||
// Codec handles message encoding/decoding with configuration
|
||||
type Codec struct {
|
||||
config *Config
|
||||
mu sync.RWMutex
|
||||
stats *Stats
|
||||
}
|
||||
|
||||
// Stats tracks codec statistics
|
||||
type Stats struct {
|
||||
MessagesSent uint64
|
||||
MessagesReceived uint64
|
||||
BytesSent uint64
|
||||
BytesReceived uint64
|
||||
Errors uint64
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCodec creates a new codec with configuration
|
||||
func NewCodec(config *Config) *Codec {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
return &Codec{
|
||||
config: config,
|
||||
stats: &Stats{},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessage creates a validated message
|
||||
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) (*Message, error) {
|
||||
if err := validateCommand(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateQueue(queue); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if headers == nil {
|
||||
headers = make(map[string]string)
|
||||
}
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Queue: queue,
|
||||
Command: cmd,
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) Serialize() ([]byte, error) {
|
||||
data, err := Marshal(m)
|
||||
if err != nil {
|
||||
if err := validateHeaders(headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Queue: queue,
|
||||
Command: cmd,
|
||||
Payload: payload,
|
||||
Version: ProtocolVersion,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate performs message validation
|
||||
func (m *Message) Validate(config *Config) error {
|
||||
if m == nil {
|
||||
return ErrInvalidMessage
|
||||
}
|
||||
|
||||
if m.Version != ProtocolVersion {
|
||||
return ErrProtocolMismatch
|
||||
}
|
||||
|
||||
if err := validateCommand(m.Command); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateQueue(m.Queue); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateHeaders(m.Headers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(m.Payload) > int(config.MaxMessageSize) {
|
||||
return ErrMessageTooLarge
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serialize converts message to bytes with validation
|
||||
func (m *Message) Serialize() ([]byte, error) {
|
||||
if m == nil {
|
||||
return nil, ErrInvalidMessage
|
||||
}
|
||||
|
||||
data, err := Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialization failed: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Deserialize converts bytes to message with validation
|
||||
func Deserialize(data []byte) (*Message, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, ErrInvalidMessage
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := Unmarshal(data, &msg); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("deserialization failed: %w", err)
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||
// SendMessage sends a message with proper error handling and timeouts
|
||||
func (c *Codec) SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||
// Check context cancellation before proceeding
|
||||
if err := ctx.Err(); err != nil {
|
||||
c.incrementErrors()
|
||||
return fmt.Errorf("context ended before send: %w", err)
|
||||
}
|
||||
|
||||
if msg == nil {
|
||||
return ErrInvalidMessage
|
||||
}
|
||||
|
||||
// Validate message
|
||||
if err := msg.Validate(c.config); err != nil {
|
||||
c.incrementErrors()
|
||||
return fmt.Errorf("message validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is a fragment message, if so handle it directly
|
||||
if msg.Type == MessageTypeFragment {
|
||||
return c.sendRawMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
// Handle fragmentation for large messages if needed
|
||||
if len(msg.Payload) > int(FragmentationThreshold) && msg.Type != MessageTypeFragment {
|
||||
fm := NewFragmentManager(c, c.config)
|
||||
defer fm.Stop()
|
||||
return fm.sendFragmentedMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
// Standard message send path
|
||||
return c.sendRawMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
// sendRawMessage handles the actual sending of a message or fragment
|
||||
func (c *Codec) sendRawMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||
// Serialize message
|
||||
data, err := msg.Serialize()
|
||||
if err != nil {
|
||||
return err
|
||||
c.incrementErrors()
|
||||
return fmt.Errorf("message serialization failed: %w", err)
|
||||
}
|
||||
|
||||
// Check message size
|
||||
if len(data) > int(c.config.MaxMessageSize) {
|
||||
c.incrementErrors()
|
||||
return ErrMessageTooLarge
|
||||
}
|
||||
|
||||
// Prepare buffer
|
||||
totalLength := 4 + len(data)
|
||||
buffer := bpool.Get()
|
||||
defer bpool.Put(buffer)
|
||||
buffer.Reset()
|
||||
|
||||
if cap(buffer.B) < totalLength {
|
||||
buffer.B = make([]byte, totalLength)
|
||||
} else {
|
||||
buffer.B = buffer.B[:totalLength]
|
||||
}
|
||||
|
||||
// Write length prefix and data
|
||||
binary.BigEndian.PutUint32(buffer.B[:4], uint32(len(data)))
|
||||
copy(buffer.B[4:], data)
|
||||
|
||||
// Set timeout
|
||||
deadline := time.Now().Add(c.config.WriteTimeout)
|
||||
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
|
||||
deadline = ctxDeadline
|
||||
}
|
||||
|
||||
if err := conn.SetWriteDeadline(deadline); err != nil {
|
||||
c.incrementErrors()
|
||||
return fmt.Errorf("failed to set write deadline: %w", err)
|
||||
}
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
|
||||
// Write with buffering
|
||||
writer := bufio.NewWriter(conn)
|
||||
|
||||
// Set write deadline if context has one
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
conn.SetWriteDeadline(deadline)
|
||||
defer conn.SetWriteDeadline(time.Time{})
|
||||
written, err := writer.Write(buffer.B[:totalLength])
|
||||
if err != nil {
|
||||
c.incrementErrors()
|
||||
return fmt.Errorf("write failed after %d bytes: %w", written, err)
|
||||
}
|
||||
|
||||
// Write full data
|
||||
if _, err := writer.Write(buffer.B[:totalLength]); err != nil {
|
||||
return err
|
||||
if err := writer.Flush(); err != nil {
|
||||
c.incrementErrors()
|
||||
return fmt.Errorf("flush failed: %w", err)
|
||||
}
|
||||
return writer.Flush()
|
||||
|
||||
// Update statistics
|
||||
c.stats.mu.Lock()
|
||||
c.stats.MessagesSent++
|
||||
c.stats.BytesSent += uint64(totalLength)
|
||||
c.stats.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
|
||||
// ReadMessage reads a message with proper error handling and timeouts
|
||||
func (c *Codec) ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
|
||||
// Check context cancellation before proceeding
|
||||
if err := ctx.Err(); err != nil {
|
||||
c.incrementErrors()
|
||||
return nil, fmt.Errorf("context ended before read: %w", err)
|
||||
}
|
||||
|
||||
// Check context cancellation before proceeding
|
||||
if err := ctx.Err(); err != nil {
|
||||
c.incrementErrors()
|
||||
return nil, fmt.Errorf("context ended before read: %w", err)
|
||||
}
|
||||
|
||||
// Set timeout
|
||||
deadline := time.Now().Add(c.config.ReadTimeout)
|
||||
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
|
||||
deadline = ctxDeadline
|
||||
}
|
||||
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
c.incrementErrors()
|
||||
return nil, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Read length prefix
|
||||
lengthBytes := make([]byte, 4)
|
||||
// Set read deadline if context has one
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
conn.SetReadDeadline(deadline)
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
// Use io.ReadFull to ensure all header bytes are read
|
||||
if _, err := io.ReadFull(conn, lengthBytes); err != nil {
|
||||
return nil, err
|
||||
c.incrementErrors()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read message length: %w", err)
|
||||
}
|
||||
|
||||
length := binary.BigEndian.Uint32(lengthBytes)
|
||||
|
||||
// Validate message size
|
||||
if length > c.config.MaxMessageSize {
|
||||
c.incrementErrors()
|
||||
return nil, ErrMessageTooLarge
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
c.incrementErrors()
|
||||
return nil, ErrInvalidMessage
|
||||
}
|
||||
|
||||
// Read message data
|
||||
buffer := bpool.Get()
|
||||
defer bpool.Put(buffer)
|
||||
buffer.Reset()
|
||||
|
||||
if cap(buffer.B) < int(length) {
|
||||
buffer.B = make([]byte, length)
|
||||
} else {
|
||||
buffer.B = buffer.B[:length]
|
||||
}
|
||||
// Read the entire message payload
|
||||
|
||||
if _, err := io.ReadFull(conn, buffer.B[:length]); err != nil {
|
||||
return nil, err
|
||||
c.incrementErrors()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read message data: %w", err)
|
||||
}
|
||||
return Deserialize(buffer.B[:length])
|
||||
|
||||
// Deserialize message
|
||||
msg, err := Deserialize(buffer.B[:length])
|
||||
if err != nil {
|
||||
c.incrementErrors()
|
||||
return nil, fmt.Errorf("failed to deserialize message: %w", err)
|
||||
}
|
||||
|
||||
// Validate message
|
||||
if err := msg.Validate(c.config); err != nil {
|
||||
c.incrementErrors()
|
||||
return nil, fmt.Errorf("message validation failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle message fragments if needed
|
||||
if msg.Type == MessageTypeFragment || (msg.Flags&FlagFragmented) != 0 {
|
||||
fm := NewFragmentManager(c, c.config)
|
||||
reassembled, isFragment, err := fm.processFragment(msg)
|
||||
if err != nil {
|
||||
c.incrementErrors()
|
||||
return nil, fmt.Errorf("fragment processing failed: %w", err)
|
||||
}
|
||||
|
||||
// If this is a fragment but reassembly isn't complete yet
|
||||
if isFragment && reassembled == nil {
|
||||
// Update statistics but return nil with no error to indicate
|
||||
// the caller should continue reading messages
|
||||
c.stats.mu.Lock()
|
||||
c.stats.MessagesReceived++
|
||||
c.stats.BytesReceived += uint64(4 + length)
|
||||
c.stats.mu.Unlock()
|
||||
|
||||
// Read the next fragment
|
||||
return c.ReadMessage(ctx, conn)
|
||||
}
|
||||
|
||||
// Use the reassembled message if available
|
||||
if reassembled != nil {
|
||||
msg = reassembled
|
||||
}
|
||||
}
|
||||
|
||||
// Update statistics
|
||||
c.stats.mu.Lock()
|
||||
c.stats.MessagesReceived++
|
||||
c.stats.BytesReceived += uint64(4 + length)
|
||||
c.stats.mu.Unlock()
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// GetStats returns codec statistics
|
||||
func (c *Codec) GetStats() Stats {
|
||||
c.stats.mu.RLock()
|
||||
defer c.stats.mu.RUnlock()
|
||||
return *c.stats
|
||||
}
|
||||
|
||||
// ResetStats resets codec statistics
|
||||
func (c *Codec) ResetStats() {
|
||||
c.stats.mu.Lock()
|
||||
defer c.stats.mu.Unlock()
|
||||
*c.stats = Stats{}
|
||||
}
|
||||
|
||||
// Helper functions for validation
|
||||
func validateCommand(cmd consts.CMD) error {
|
||||
// Add validation based on your command constants
|
||||
if cmd < 0 {
|
||||
return ErrInvalidCommand
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateQueue(queue string) error {
|
||||
if len(queue) == 0 || len(queue) > MaxQueueLength {
|
||||
return ErrInvalidQueue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHeaders(headers map[string]string) error {
|
||||
totalSize := 0
|
||||
for k, v := range headers {
|
||||
totalSize += len(k) + len(v)
|
||||
if totalSize > MaxHeaderSize {
|
||||
return ErrMessageTooLarge
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Codec) incrementErrors() {
|
||||
c.stats.mu.Lock()
|
||||
c.stats.Errors++
|
||||
c.stats.mu.Unlock()
|
||||
}
|
||||
|
||||
// SendMessage Backward compatibility functions
|
||||
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||
codec := NewCodec(DefaultConfig())
|
||||
return codec.SendMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
|
||||
codec := NewCodec(DefaultConfig())
|
||||
return codec.ReadMessage(ctx, conn)
|
||||
}
|
||||
|
282
codec/fragment.go
Normal file
282
codec/fragment.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
// FragmentManager handles message fragmentation and reassembly
|
||||
type FragmentManager struct {
|
||||
codec *Codec
|
||||
config *Config
|
||||
fragmentStore map[string]*fragmentAssembly
|
||||
mu sync.RWMutex
|
||||
cleanupInterval time.Duration
|
||||
cleanupTimer *time.Timer
|
||||
cleanupChan chan struct{}
|
||||
reassemblyTimeout time.Duration
|
||||
}
|
||||
|
||||
// fragmentAssembly holds fragments for a specific message
|
||||
type fragmentAssembly struct {
|
||||
fragments map[uint16][]byte
|
||||
totalSize int
|
||||
createdAt time.Time
|
||||
totalCount uint16
|
||||
messageType MessageType
|
||||
queue string
|
||||
command consts.CMD
|
||||
headers map[string]string
|
||||
id string
|
||||
}
|
||||
|
||||
// NewFragmentManager creates a new fragment manager
|
||||
func NewFragmentManager(codec *Codec, config *Config) *FragmentManager {
|
||||
fm := &FragmentManager{
|
||||
codec: codec,
|
||||
config: config,
|
||||
fragmentStore: make(map[string]*fragmentAssembly),
|
||||
cleanupInterval: time.Minute,
|
||||
reassemblyTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
fm.startCleanupTimer()
|
||||
return fm
|
||||
}
|
||||
|
||||
// startCleanupTimer starts a timer to clean up old fragments
|
||||
func (fm *FragmentManager) startCleanupTimer() {
|
||||
fm.cleanupChan = make(chan struct{})
|
||||
fm.cleanupTimer = time.NewTimer(fm.cleanupInterval)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-fm.cleanupTimer.C:
|
||||
fm.cleanupExpiredFragments()
|
||||
fm.cleanupTimer.Reset(fm.cleanupInterval)
|
||||
case <-fm.cleanupChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop stops the fragment manager
|
||||
func (fm *FragmentManager) Stop() {
|
||||
if fm.cleanupTimer != nil {
|
||||
fm.cleanupTimer.Stop()
|
||||
}
|
||||
|
||||
if fm.cleanupChan != nil {
|
||||
close(fm.cleanupChan)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredFragments removes expired fragment assemblies
|
||||
func (fm *FragmentManager) cleanupExpiredFragments() {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, assembly := range fm.fragmentStore {
|
||||
if now.Sub(assembly.createdAt) > fm.reassemblyTimeout {
|
||||
delete(fm.fragmentStore, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fragmentMessage splits a large message into fragments
|
||||
func (fm *FragmentManager) fragmentMessage(ctx context.Context, msg *Message) ([]*Message, error) {
|
||||
if msg == nil {
|
||||
return nil, ErrInvalidMessage
|
||||
}
|
||||
|
||||
// Check if fragmentation is needed
|
||||
if len(msg.Payload) <= int(FragmentationThreshold) {
|
||||
return []*Message{msg}, nil
|
||||
}
|
||||
|
||||
// Calculate how many fragments we need
|
||||
fragmentCount := (len(msg.Payload) + int(FragmentSize) - 1) / int(FragmentSize)
|
||||
if fragmentCount > int(MaxFragments) {
|
||||
return nil, fmt.Errorf("%w: would require %d fragments, maximum is %d",
|
||||
ErrMessageTooLarge, fragmentCount, MaxFragments)
|
||||
}
|
||||
|
||||
// Generate a fragment ID using a hash of message content + timestamp
|
||||
idBytes := []byte(msg.ID)
|
||||
timestampBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timestampBytes, uint64(msg.Timestamp))
|
||||
hashInput := append(idBytes, timestampBytes...)
|
||||
hashInput = append(hashInput, msg.Payload[:min(1024, len(msg.Payload))]...)
|
||||
fragmentID := crc32.ChecksumIEEE(hashInput)
|
||||
|
||||
// Create fragment messages
|
||||
fragments := make([]*Message, fragmentCount)
|
||||
for i := 0; i < fragmentCount; i++ {
|
||||
// Calculate fragment payload boundaries
|
||||
start := i * int(FragmentSize)
|
||||
end := min((i+1)*int(FragmentSize), len(msg.Payload))
|
||||
fragmentPayload := msg.Payload[start:end]
|
||||
|
||||
// Create fragment message
|
||||
fragment := &Message{
|
||||
Headers: copyHeaders(msg.Headers),
|
||||
Queue: msg.Queue,
|
||||
Command: msg.Command,
|
||||
Version: msg.Version,
|
||||
Timestamp: msg.Timestamp,
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeFragment,
|
||||
Flags: msg.Flags | FlagFragmented,
|
||||
FragmentID: fragmentID,
|
||||
Fragments: uint16(fragmentCount),
|
||||
Sequence: uint16(i),
|
||||
Payload: fragmentPayload,
|
||||
}
|
||||
|
||||
fragments[i] = fragment
|
||||
}
|
||||
|
||||
return fragments, nil
|
||||
}
|
||||
|
||||
// sendFragmentedMessage sends a large message as multiple fragments
|
||||
func (fm *FragmentManager) sendFragmentedMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||
fragments, err := fm.fragmentMessage(ctx, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If no fragmentation was needed, send as normal
|
||||
if len(fragments) == 1 && fragments[0].Type != MessageTypeFragment {
|
||||
return fm.codec.SendMessage(ctx, conn, fragments[0])
|
||||
}
|
||||
|
||||
// Send each fragment
|
||||
for _, fragment := range fragments {
|
||||
if err := fm.codec.SendMessage(ctx, conn, fragment); err != nil {
|
||||
return fmt.Errorf("failed to send fragment %d/%d: %w",
|
||||
fragment.Sequence+1, fragment.Fragments, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processFragment processes a fragment and attempts reassembly
|
||||
func (fm *FragmentManager) processFragment(msg *Message) (*Message, bool, error) {
|
||||
if msg == nil || msg.Type != MessageTypeFragment || msg.FragmentID == 0 {
|
||||
return msg, false, nil // Not a fragment, return as is
|
||||
}
|
||||
|
||||
// Generate a unique key for this fragmented message
|
||||
key := fmt.Sprintf("%s-%d-%s", msg.ID, msg.FragmentID, msg.Queue)
|
||||
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
// Check if we already have an assembly for this message
|
||||
assembly, exists := fm.fragmentStore[key]
|
||||
if !exists {
|
||||
// Create a new assembly
|
||||
assembly = &fragmentAssembly{
|
||||
fragments: make(map[uint16][]byte),
|
||||
createdAt: time.Now(),
|
||||
totalCount: msg.Fragments,
|
||||
messageType: MessageTypeStandard,
|
||||
queue: msg.Queue,
|
||||
command: msg.Command,
|
||||
headers: copyHeaders(msg.Headers),
|
||||
id: msg.ID,
|
||||
}
|
||||
fm.fragmentStore[key] = assembly
|
||||
}
|
||||
|
||||
// Store this fragment
|
||||
assembly.fragments[msg.Sequence] = msg.Payload
|
||||
assembly.totalSize += len(msg.Payload)
|
||||
|
||||
// Check if we have all fragments
|
||||
if len(assembly.fragments) == int(assembly.totalCount) {
|
||||
// We have all fragments, reassemble the message
|
||||
reassembled, err := fm.reassembleMessage(key, assembly)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return reassembled, true, nil
|
||||
}
|
||||
|
||||
// Still waiting for more fragments
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
// reassembleMessage combines fragments into the original message
|
||||
func (fm *FragmentManager) reassembleMessage(key string, assembly *fragmentAssembly) (*Message, error) {
|
||||
// Remove the assembly from the store
|
||||
delete(fm.fragmentStore, key)
|
||||
|
||||
// Check if we have all fragments
|
||||
if len(assembly.fragments) != int(assembly.totalCount) {
|
||||
return nil, ErrFragmentMissing
|
||||
}
|
||||
|
||||
// Allocate space for the full payload
|
||||
fullPayload := make([]byte, assembly.totalSize)
|
||||
|
||||
// Combine fragments in order
|
||||
offset := 0
|
||||
for i := uint16(0); i < assembly.totalCount; i++ {
|
||||
fragment, exists := assembly.fragments[i]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("%w: missing fragment %d", ErrFragmentMissing, i)
|
||||
}
|
||||
|
||||
copy(fullPayload[offset:], fragment)
|
||||
offset += len(fragment)
|
||||
}
|
||||
|
||||
// Create the reassembled message
|
||||
msg := &Message{
|
||||
Headers: assembly.headers,
|
||||
Queue: assembly.queue,
|
||||
Command: assembly.command,
|
||||
Version: ProtocolVersion,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ID: assembly.id,
|
||||
Type: assembly.messageType,
|
||||
Flags: FlagNone, // Clear fragmentation flag
|
||||
Payload: fullPayload,
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Helper function to make a copy of headers to avoid shared references
|
||||
func copyHeaders(src map[string]string) map[string]string {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dst := make(map[string]string, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// min returns the smaller of two integers
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
193
codec/heartbeat.go
Normal file
193
codec/heartbeat.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
// HeartbeatManager manages heartbeat messages for connection health monitoring
|
||||
type HeartbeatManager struct {
|
||||
codec *Codec
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
conn net.Conn
|
||||
stopChan chan struct{}
|
||||
lastHeartbeat atomic.Int64
|
||||
lastReceived atomic.Int64
|
||||
onFailure func(error)
|
||||
mu sync.RWMutex
|
||||
isRunning bool
|
||||
heartbeatsSent uint64
|
||||
heartbeatsRecv uint64
|
||||
failedHeartbeats uint64
|
||||
}
|
||||
|
||||
// NewHeartbeatManager creates a new heartbeat manager
|
||||
func NewHeartbeatManager(codec *Codec, conn net.Conn) *HeartbeatManager {
|
||||
return &HeartbeatManager{
|
||||
codec: codec,
|
||||
interval: 30 * time.Second,
|
||||
timeout: 90 * time.Second, // 3x interval
|
||||
conn: conn,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterval sets the heartbeat interval
|
||||
func (hm *HeartbeatManager) SetInterval(interval time.Duration) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
hm.interval = interval
|
||||
}
|
||||
|
||||
// SetTimeout sets the heartbeat timeout
|
||||
func (hm *HeartbeatManager) SetTimeout(timeout time.Duration) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
hm.timeout = timeout
|
||||
}
|
||||
|
||||
// SetOnFailure sets the callback function for heartbeat failures
|
||||
func (hm *HeartbeatManager) SetOnFailure(fn func(error)) {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
hm.onFailure = fn
|
||||
}
|
||||
|
||||
// Start starts the heartbeat monitoring
|
||||
func (hm *HeartbeatManager) Start() {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if hm.isRunning {
|
||||
return
|
||||
}
|
||||
|
||||
hm.isRunning = true
|
||||
hm.lastHeartbeat.Store(time.Now().Unix())
|
||||
hm.lastReceived.Store(time.Now().Unix())
|
||||
|
||||
go hm.sendHeartbeats()
|
||||
go hm.monitorHeartbeats()
|
||||
}
|
||||
|
||||
// Stop stops the heartbeat monitoring
|
||||
func (hm *HeartbeatManager) Stop() {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if !hm.isRunning {
|
||||
return
|
||||
}
|
||||
|
||||
hm.isRunning = false
|
||||
close(hm.stopChan)
|
||||
// Create a new stop channel for future use
|
||||
hm.stopChan = make(chan struct{})
|
||||
}
|
||||
|
||||
// IsRunning returns whether the heartbeat manager is running
|
||||
func (hm *HeartbeatManager) IsRunning() bool {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
return hm.isRunning
|
||||
}
|
||||
|
||||
// GetStats returns heartbeat statistics
|
||||
func (hm *HeartbeatManager) GetStats() map[string]uint64 {
|
||||
return map[string]uint64{
|
||||
"sent": atomic.LoadUint64(&hm.heartbeatsSent),
|
||||
"received": atomic.LoadUint64(&hm.heartbeatsRecv),
|
||||
"failed": atomic.LoadUint64(&hm.failedHeartbeats),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordHeartbeat records a received heartbeat
|
||||
func (hm *HeartbeatManager) RecordHeartbeat() {
|
||||
hm.lastReceived.Store(time.Now().Unix())
|
||||
atomic.AddUint64(&hm.heartbeatsRecv, 1)
|
||||
}
|
||||
|
||||
// sendHeartbeats sends periodic heartbeat messages
|
||||
func (hm *HeartbeatManager) sendHeartbeats() {
|
||||
ticker := time.NewTicker(hm.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := hm.sendHeartbeat(); err != nil {
|
||||
atomic.AddUint64(&hm.failedHeartbeats, 1)
|
||||
hm.mu.RLock()
|
||||
onFailure := hm.onFailure
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if onFailure != nil {
|
||||
onFailure(err)
|
||||
}
|
||||
}
|
||||
case <-hm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorHeartbeats monitors the heartbeat health
|
||||
func (hm *HeartbeatManager) monitorHeartbeats() {
|
||||
ticker := time.NewTicker(hm.interval / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now().Unix()
|
||||
lastReceived := hm.lastReceived.Load()
|
||||
|
||||
// Check if we've exceeded the timeout
|
||||
if now-lastReceived > int64(hm.timeout.Seconds()) {
|
||||
err := fmt.Errorf("heartbeat timeout: last received %v seconds ago",
|
||||
now-lastReceived)
|
||||
|
||||
atomic.AddUint64(&hm.failedHeartbeats, 1)
|
||||
hm.mu.RLock()
|
||||
onFailure := hm.onFailure
|
||||
hm.mu.RUnlock()
|
||||
|
||||
if onFailure != nil {
|
||||
onFailure(err)
|
||||
}
|
||||
}
|
||||
case <-hm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHeartbeat sends a single heartbeat message
|
||||
func (hm *HeartbeatManager) sendHeartbeat() error {
|
||||
msg := &Message{
|
||||
Type: MessageTypeHeartbeat,
|
||||
Command: consts.CMD(0), // Use appropriate heartbeat command from your consts
|
||||
Version: ProtocolVersion,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Headers: map[string]string{"type": "heartbeat"},
|
||||
Payload: []byte{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := hm.codec.SendMessage(ctx, hm.conn, msg)
|
||||
if err == nil {
|
||||
hm.lastHeartbeat.Store(time.Now().Unix())
|
||||
atomic.AddUint64(&hm.heartbeatsSent, 1)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
@@ -1,39 +1,418 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/json"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
type MarshallerFunc func(v any) ([]byte, error)
|
||||
// Error definitions for serialization
|
||||
var (
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
ErrCompressionFailed = errors.New("compression failed")
|
||||
ErrDecompressionFailed = errors.New("decompression failed")
|
||||
ErrEncryptionFailed = errors.New("encryption failed")
|
||||
ErrDecryptionFailed = errors.New("decryption failed")
|
||||
ErrInvalidKey = errors.New("invalid encryption key")
|
||||
)
|
||||
|
||||
type UnmarshallerFunc func(data []byte, v any) error
|
||||
// ContentType represents the content type of serialized data
|
||||
type ContentType string
|
||||
|
||||
const (
|
||||
ContentTypeJSON ContentType = "application/json"
|
||||
ContentTypeMsgPack ContentType = "application/msgpack"
|
||||
ContentTypeCBOR ContentType = "application/cbor"
|
||||
)
|
||||
|
||||
// Marshaller interface for pluggable serialization
|
||||
type Marshaller interface {
|
||||
Marshal(v any) ([]byte, error)
|
||||
ContentType() ContentType
|
||||
}
|
||||
|
||||
// Unmarshaller interface for pluggable deserialization
|
||||
type Unmarshaller interface {
|
||||
Unmarshal(data []byte, v any) error
|
||||
ContentType() ContentType
|
||||
}
|
||||
|
||||
// MarshallerFunc adapter
|
||||
type MarshallerFunc func(v any) ([]byte, error)
|
||||
|
||||
func (f MarshallerFunc) Marshal(v any) ([]byte, error) {
|
||||
return f(v)
|
||||
}
|
||||
|
||||
func (f MarshallerFunc) ContentType() ContentType {
|
||||
return ContentTypeJSON
|
||||
}
|
||||
|
||||
// UnmarshallerFunc adapter
|
||||
type UnmarshallerFunc func(data []byte, v any) error
|
||||
|
||||
func (f UnmarshallerFunc) Unmarshal(data []byte, v any) error {
|
||||
return f(data, v)
|
||||
}
|
||||
|
||||
var defaultMarshaller MarshallerFunc = json.Marshal
|
||||
|
||||
var defaultUnmarshaller UnmarshallerFunc = func(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
func (f UnmarshallerFunc) ContentType() ContentType {
|
||||
return ContentTypeJSON
|
||||
}
|
||||
|
||||
// SerializationConfig holds serialization configuration
|
||||
type SerializationConfig struct {
|
||||
EnableCompression bool
|
||||
CompressionLevel int
|
||||
MaxCompressionRatio float64
|
||||
EnableEncryption bool
|
||||
EncryptionKey []byte
|
||||
PreferredCipher string // "chacha20poly1305" or "aes-gcm"
|
||||
}
|
||||
|
||||
// DefaultSerializationConfig returns default configuration
|
||||
func DefaultSerializationConfig() *SerializationConfig {
|
||||
return &SerializationConfig{
|
||||
EnableCompression: false,
|
||||
CompressionLevel: gzip.DefaultCompression,
|
||||
MaxCompressionRatio: 0.8, // Only compress if we save at least 20%
|
||||
EnableEncryption: false,
|
||||
PreferredCipher: "chacha20poly1305",
|
||||
}
|
||||
}
|
||||
|
||||
// SerializationManager manages serialization with configuration
|
||||
type SerializationManager struct {
|
||||
marshaller Marshaller
|
||||
unmarshaller Unmarshaller
|
||||
config *SerializationConfig
|
||||
mu sync.RWMutex
|
||||
cachedCipher cipher.AEAD
|
||||
cipherMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSerializationManager creates a new serialization manager
|
||||
func NewSerializationManager(config *SerializationConfig) *SerializationManager {
|
||||
if config == nil {
|
||||
config = DefaultSerializationConfig()
|
||||
}
|
||||
|
||||
return &SerializationManager{
|
||||
marshaller: MarshallerFunc(json.Marshal),
|
||||
unmarshaller: UnmarshallerFunc(func(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMarshaller sets custom marshaller
|
||||
func (sm *SerializationManager) SetMarshaller(marshaller Marshaller) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.marshaller = marshaller
|
||||
}
|
||||
|
||||
// SetUnmarshaller sets custom unmarshaller
|
||||
func (sm *SerializationManager) SetUnmarshaller(unmarshaller Unmarshaller) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.unmarshaller = unmarshaller
|
||||
}
|
||||
|
||||
// SetEncryptionKey sets the encryption key
|
||||
func (sm *SerializationManager) SetEncryptionKey(key []byte) error {
|
||||
if sm.config.EnableEncryption && len(key) == 0 {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.config.EncryptionKey = key
|
||||
// Clear the cached cipher so it will be recreated with the new key
|
||||
sm.cipherMu.Lock()
|
||||
defer sm.cipherMu.Unlock()
|
||||
sm.cachedCipher = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal serializes data with optional compression and encryption
|
||||
func (sm *SerializationManager) Marshal(v any) ([]byte, error) {
|
||||
sm.mu.RLock()
|
||||
marshaller := sm.marshaller
|
||||
config := sm.config
|
||||
sm.mu.RUnlock()
|
||||
|
||||
// Serialize the data
|
||||
data, err := marshaller.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrSerializationFailed, err)
|
||||
}
|
||||
|
||||
// Apply compression if enabled and beneficial
|
||||
if config.EnableCompression && len(data) > 256 { // Only compress larger payloads
|
||||
compressed, err := sm.compress(data)
|
||||
if err != nil {
|
||||
// Continue with uncompressed data on error
|
||||
// but don't return error to allow for graceful degradation
|
||||
} else if float64(len(compressed))/float64(len(data)) <= config.MaxCompressionRatio {
|
||||
// Only use compression if it provides significant benefit
|
||||
data = compressed
|
||||
}
|
||||
}
|
||||
|
||||
// Apply encryption if enabled
|
||||
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
|
||||
encrypted, err := sm.encrypt(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrEncryptionFailed, err)
|
||||
}
|
||||
data = encrypted
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Unmarshal deserializes data with optional decompression and decryption
|
||||
func (sm *SerializationManager) Unmarshal(data []byte, v any) error {
|
||||
if len(data) == 0 {
|
||||
return fmt.Errorf("%w: empty data", ErrDeserializationFailed)
|
||||
}
|
||||
|
||||
sm.mu.RLock()
|
||||
unmarshaller := sm.unmarshaller
|
||||
config := sm.config
|
||||
sm.mu.RUnlock()
|
||||
|
||||
// Apply decryption if enabled
|
||||
if config.EnableEncryption && len(config.EncryptionKey) > 0 {
|
||||
decrypted, err := sm.decrypt(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDecryptionFailed, err)
|
||||
}
|
||||
data = decrypted
|
||||
}
|
||||
|
||||
// Try decompression if enabled
|
||||
if config.EnableCompression && sm.isCompressed(data) {
|
||||
decompressed, err := sm.decompress(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDecompressionFailed, err)
|
||||
}
|
||||
data = decompressed
|
||||
}
|
||||
|
||||
// Deserialize the data
|
||||
if err := unmarshaller.Unmarshal(data, v); err != nil {
|
||||
return fmt.Errorf("%w: %v", ErrDeserializationFailed, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// compress compresses data using gzip
|
||||
func (sm *SerializationManager) compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Add compression marker
|
||||
buf.WriteByte(0x1f) // gzip magic number
|
||||
|
||||
writer, err := gzip.NewWriterLevel(&buf, sm.config.CompressionLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := writer.Write(data); err != nil {
|
||||
writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// decompress decompresses gzip data
|
||||
func (sm *SerializationManager) decompress(data []byte) ([]byte, error) {
|
||||
if len(data) < 1 {
|
||||
return nil, fmt.Errorf("invalid compressed data")
|
||||
}
|
||||
|
||||
// Remove compression marker
|
||||
if data[0] != 0x1f {
|
||||
return data, nil // Not compressed
|
||||
}
|
||||
|
||||
reader, err := gzip.NewReader(bytes.NewReader(data[1:]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := buf.ReadFrom(reader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// isCompressed checks if data is compressed
|
||||
func (sm *SerializationManager) isCompressed(data []byte) bool {
|
||||
return len(data) > 0 && data[0] == 0x1f
|
||||
}
|
||||
|
||||
// getCipher returns a cipher.AEAD instance for encryption/decryption
|
||||
func (sm *SerializationManager) getCipher() (cipher.AEAD, error) {
|
||||
sm.cipherMu.Lock()
|
||||
defer sm.cipherMu.Unlock()
|
||||
|
||||
if sm.cachedCipher != nil {
|
||||
return sm.cachedCipher, nil
|
||||
}
|
||||
|
||||
var aead cipher.AEAD
|
||||
var err error
|
||||
|
||||
sm.mu.RLock()
|
||||
key := sm.config.EncryptionKey
|
||||
preferredCipher := sm.config.PreferredCipher
|
||||
sm.mu.RUnlock()
|
||||
|
||||
switch preferredCipher {
|
||||
case "chacha20poly1305":
|
||||
aead, err = chacha20poly1305.New(key)
|
||||
case "aes-gcm":
|
||||
block, e := aes.NewCipher(key)
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
aead, err = cipher.NewGCM(block)
|
||||
default:
|
||||
// Default to ChaCha20-Poly1305
|
||||
aead, err = chacha20poly1305.New(key)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sm.cachedCipher = aead
|
||||
return aead, nil
|
||||
}
|
||||
|
||||
// encrypt encrypts data using authenticated encryption
|
||||
func (sm *SerializationManager) encrypt(data []byte) ([]byte, error) {
|
||||
aead, err := sm.getCipher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add timestamp to associated data for replay protection
|
||||
timestampBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(timestampBytes, uint64(time.Now().Unix()))
|
||||
|
||||
// Encrypt and authenticate
|
||||
ciphertext := aead.Seal(nil, nonce, data, timestampBytes)
|
||||
|
||||
// Prepend nonce and timestamp to ciphertext
|
||||
result := make([]byte, len(nonce)+len(timestampBytes)+len(ciphertext))
|
||||
copy(result, nonce)
|
||||
copy(result[len(nonce):], timestampBytes)
|
||||
copy(result[len(nonce)+len(timestampBytes):], ciphertext)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// decrypt decrypts data using authenticated decryption
|
||||
func (sm *SerializationManager) decrypt(data []byte) ([]byte, error) {
|
||||
aead, err := sm.getCipher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract nonce
|
||||
nonceSize := aead.NonceSize()
|
||||
if len(data) < nonceSize+8 { // 8 bytes for timestamp
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce := data[:nonceSize]
|
||||
timestampBytes := data[nonceSize : nonceSize+8]
|
||||
ciphertext := data[nonceSize+8:]
|
||||
|
||||
// Decrypt and verify
|
||||
plaintext, err := aead.Open(nil, nonce, ciphertext, timestampBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Global serialization manager instance
|
||||
var globalSerializationManager = NewSerializationManager(nil)
|
||||
var globalSerializationManagerMu sync.RWMutex
|
||||
|
||||
// Global functions for backward compatibility
|
||||
func SetMarshaller(marshaller MarshallerFunc) {
|
||||
defaultMarshaller = marshaller
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.SetMarshaller(marshaller)
|
||||
}
|
||||
|
||||
func SetUnmarshaller(unmarshaller UnmarshallerFunc) {
|
||||
defaultUnmarshaller = unmarshaller
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.SetUnmarshaller(unmarshaller)
|
||||
}
|
||||
|
||||
func EnableCompression(enable bool) {
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.config.EnableCompression = enable
|
||||
}
|
||||
|
||||
func EnableEncryption(enable bool, key []byte) error {
|
||||
globalSerializationManagerMu.Lock()
|
||||
defer globalSerializationManagerMu.Unlock()
|
||||
globalSerializationManager.config.EnableEncryption = enable
|
||||
if enable {
|
||||
return globalSerializationManager.SetEncryptionKey(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return defaultMarshaller(v)
|
||||
globalSerializationManagerMu.RLock()
|
||||
defer globalSerializationManagerMu.RUnlock()
|
||||
return globalSerializationManager.Marshal(v)
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return defaultUnmarshaller(data, v)
|
||||
globalSerializationManagerMu.RLock()
|
||||
defer globalSerializationManagerMu.RUnlock()
|
||||
return globalSerializationManager.Unmarshal(data, v)
|
||||
}
|
||||
|
210
codec/testing.go
Normal file
210
codec/testing.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
// MockConn implements net.Conn for testing
|
||||
type MockConn struct {
|
||||
ReadBuffer *bytes.Buffer
|
||||
WriteBuffer *bytes.Buffer
|
||||
ReadDelay time.Duration
|
||||
WriteDelay time.Duration
|
||||
IsClosed bool
|
||||
ReadErr error
|
||||
WriteErr error
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewMockConn creates a new mock connection
|
||||
func NewMockConn() *MockConn {
|
||||
return &MockConn{
|
||||
ReadBuffer: bytes.NewBuffer(nil),
|
||||
WriteBuffer: bytes.NewBuffer(nil),
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the net.Conn Read method
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.IsClosed {
|
||||
return 0, errors.New("connection closed")
|
||||
}
|
||||
|
||||
if m.ReadErr != nil {
|
||||
return 0, m.ReadErr
|
||||
}
|
||||
|
||||
if m.ReadDelay > 0 {
|
||||
time.Sleep(m.ReadDelay)
|
||||
}
|
||||
|
||||
return m.ReadBuffer.Read(b)
|
||||
}
|
||||
|
||||
// Write implements the net.Conn Write method
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.IsClosed {
|
||||
return 0, errors.New("connection closed")
|
||||
}
|
||||
|
||||
if m.WriteErr != nil {
|
||||
return 0, m.WriteErr
|
||||
}
|
||||
|
||||
if m.WriteDelay > 0 {
|
||||
time.Sleep(m.WriteDelay)
|
||||
}
|
||||
|
||||
return m.WriteBuffer.Write(b)
|
||||
}
|
||||
|
||||
// Close implements the net.Conn Close method
|
||||
func (m *MockConn) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.IsClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr implements the net.Conn LocalAddr method
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
|
||||
}
|
||||
|
||||
// RemoteAddr implements the net.Conn RemoteAddr method
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
|
||||
}
|
||||
|
||||
// SetDeadline implements the net.Conn SetDeadline method
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the net.Conn SetReadDeadline method
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the net.Conn SetWriteDeadline method
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CodecTestSuite provides utilities for testing the codec
|
||||
type CodecTestSuite struct {
|
||||
Codec *Codec
|
||||
Config *Config
|
||||
}
|
||||
|
||||
// NewCodecTestSuite creates a new codec test suite
|
||||
func NewCodecTestSuite() *CodecTestSuite {
|
||||
config := DefaultConfig()
|
||||
// Set smaller timeouts for testing
|
||||
config.ReadTimeout = 500 * time.Millisecond
|
||||
config.WriteTimeout = 500 * time.Millisecond
|
||||
|
||||
return &CodecTestSuite{
|
||||
Codec: NewCodec(config),
|
||||
Config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// SendReceiveTest tests sending and receiving a message
|
||||
func (ts *CodecTestSuite) SendReceiveTest(msg *Message) error {
|
||||
conn := NewMockConn()
|
||||
|
||||
// Send the message
|
||||
ctx := context.Background()
|
||||
if err := ts.Codec.SendMessage(ctx, conn, msg); err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
// Move written data to read buffer to simulate network transport
|
||||
conn.ReadBuffer.Write(conn.WriteBuffer.Bytes())
|
||||
conn.WriteBuffer.Reset()
|
||||
|
||||
// Receive the message
|
||||
received, err := ts.Codec.ReadMessage(ctx, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to receive message: %w", err)
|
||||
}
|
||||
|
||||
// Validate the message
|
||||
if received.Command != msg.Command {
|
||||
return fmt.Errorf("command mismatch: got %v, want %v", received.Command, msg.Command)
|
||||
}
|
||||
|
||||
if received.Queue != msg.Queue {
|
||||
return fmt.Errorf("queue mismatch: got %v, want %v", received.Queue, msg.Queue)
|
||||
}
|
||||
|
||||
if !bytes.Equal(received.Payload, msg.Payload) {
|
||||
return fmt.Errorf("payload mismatch: got %d bytes, want %d bytes", len(received.Payload), len(msg.Payload))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FragmentationTest tests the fragmentation and reassembly of large messages
|
||||
func (ts *CodecTestSuite) FragmentationTest(payload []byte) error {
|
||||
msg := &Message{
|
||||
Command: consts.CMD(1), // Use appropriate command from your consts
|
||||
Queue: "test_queue",
|
||||
Headers: map[string]string{"test": "header"},
|
||||
Payload: payload,
|
||||
Version: ProtocolVersion,
|
||||
Timestamp: time.Now().Unix(),
|
||||
ID: "test-message-id",
|
||||
}
|
||||
|
||||
conn := NewMockConn()
|
||||
|
||||
// Configure fragmentation
|
||||
fm := NewFragmentManager(ts.Codec, ts.Config)
|
||||
|
||||
// Send the fragmented message
|
||||
ctx := context.Background()
|
||||
if err := fm.sendFragmentedMessage(ctx, conn, msg); err != nil {
|
||||
return fmt.Errorf("failed to send fragmented message: %w", err)
|
||||
}
|
||||
|
||||
// Move written data to read buffer to simulate network transport
|
||||
conn.ReadBuffer.Write(conn.WriteBuffer.Bytes())
|
||||
conn.WriteBuffer.Reset()
|
||||
|
||||
// Receive and reassemble the message
|
||||
received, err := ts.Codec.ReadMessage(ctx, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to receive message: %w", err)
|
||||
}
|
||||
|
||||
// Validate the reassembled message
|
||||
if received.Command != msg.Command {
|
||||
return fmt.Errorf("command mismatch: got %v, want %v", received.Command, msg.Command)
|
||||
}
|
||||
|
||||
if received.Queue != msg.Queue {
|
||||
return fmt.Errorf("queue mismatch: got %v, want %v", received.Queue, msg.Queue)
|
||||
}
|
||||
|
||||
if !bytes.Equal(received.Payload, msg.Payload) {
|
||||
return fmt.Errorf("payload mismatch: got %d bytes, want %d bytes", len(received.Payload), len(msg.Payload))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,99 +0,0 @@
|
||||
{
|
||||
"broker": {
|
||||
"address": "localhost",
|
||||
"port": 8080,
|
||||
"max_connections": 1000,
|
||||
"connection_timeout": "5s",
|
||||
"read_timeout": "300s",
|
||||
"write_timeout": "30s",
|
||||
"idle_timeout": "600s",
|
||||
"keep_alive": true,
|
||||
"keep_alive_period": "60s",
|
||||
"max_queue_depth": 10000,
|
||||
"enable_dead_letter": true,
|
||||
"dead_letter_max_retries": 3
|
||||
},
|
||||
"consumer": {
|
||||
"enable_http_api": true,
|
||||
"max_retries": 5,
|
||||
"initial_delay": "2s",
|
||||
"max_backoff": "30s",
|
||||
"jitter_percent": 0.5,
|
||||
"batch_size": 10,
|
||||
"prefetch_count": 100,
|
||||
"auto_ack": false,
|
||||
"requeue_on_failure": true
|
||||
},
|
||||
"publisher": {
|
||||
"enable_http_api": true,
|
||||
"max_retries": 3,
|
||||
"initial_delay": "1s",
|
||||
"max_backoff": "10s",
|
||||
"confirm_delivery": true,
|
||||
"publish_timeout": "5s",
|
||||
"connection_pool_size": 10
|
||||
},
|
||||
"pool": {
|
||||
"queue_size": 1000,
|
||||
"max_workers": 20,
|
||||
"max_memory_load": 1073741824,
|
||||
"idle_timeout": "300s",
|
||||
"graceful_shutdown_timeout": "30s",
|
||||
"task_timeout": "60s",
|
||||
"enable_metrics": true,
|
||||
"enable_diagnostics": true
|
||||
},
|
||||
"security": {
|
||||
"enable_tls": false,
|
||||
"tls_cert_path": "./certs/server.crt",
|
||||
"tls_key_path": "./certs/server.key",
|
||||
"tls_ca_path": "./certs/ca.crt",
|
||||
"enable_auth": false,
|
||||
"auth_provider": "jwt",
|
||||
"jwt_secret": "your-secret-key",
|
||||
"enable_encryption": false,
|
||||
"encryption_key": "32-byte-encryption-key-here!!"
|
||||
},
|
||||
"monitoring": {
|
||||
"metrics_port": 9090,
|
||||
"health_check_port": 9091,
|
||||
"enable_metrics": true,
|
||||
"enable_health_checks": true,
|
||||
"metrics_interval": "10s",
|
||||
"health_check_interval": "30s",
|
||||
"retention_period": "24h",
|
||||
"enable_tracing": true,
|
||||
"jaeger_endpoint": "http://localhost:14268/api/traces"
|
||||
},
|
||||
"persistence": {
|
||||
"enable": true,
|
||||
"provider": "postgres",
|
||||
"connection_string": "postgres://user:password@localhost:5432/mq_db?sslmode=disable",
|
||||
"max_connections": 50,
|
||||
"connection_timeout": "30s",
|
||||
"enable_migrations": true,
|
||||
"backup_enabled": true,
|
||||
"backup_interval": "6h"
|
||||
},
|
||||
"clustering": {
|
||||
"enable": false,
|
||||
"node_id": "node-1",
|
||||
"cluster_name": "mq-cluster",
|
||||
"peers": [ ],
|
||||
"election_timeout": "5s",
|
||||
"heartbeat_interval": "1s",
|
||||
"enable_auto_discovery": false,
|
||||
"discovery_port": 7946
|
||||
},
|
||||
"rate_limit": {
|
||||
"broker_rate": 1000,
|
||||
"broker_burst": 100,
|
||||
"consumer_rate": 500,
|
||||
"consumer_burst": 50,
|
||||
"publisher_rate": 200,
|
||||
"publisher_burst": 20,
|
||||
"global_rate": 2000,
|
||||
"global_burst": 200
|
||||
},
|
||||
"last_updated": "2025-07-29T00:00:00Z"
|
||||
}
|
37
consumer.go
37
consumer.go
@@ -152,7 +152,10 @@ func (c *Consumer) Metrics() Metrics {
|
||||
|
||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||
msg, err := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating subscribe message: %v", err)
|
||||
}
|
||||
if err := c.send(ctx, c.conn, msg); err != nil {
|
||||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||||
}
|
||||
@@ -207,7 +210,14 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C
|
||||
func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||
reply, err := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to create MESSAGE_ACK",
|
||||
logger.Field{Key: "queue", Value: msg.Queue},
|
||||
logger.Field{Key: "task_id", Value: taskID},
|
||||
logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Send with timeout to avoid blocking
|
||||
sendCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
@@ -375,7 +385,14 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||
}
|
||||
}
|
||||
bt, _ := json.Marshal(result)
|
||||
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||||
reply, err := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to create MESSAGE_RESPONSE",
|
||||
logger.Field{Key: "topic", Value: result.Topic},
|
||||
logger.Field{Key: "task_id", Value: result.TaskID},
|
||||
logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -395,7 +412,14 @@ func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, er
|
||||
// Send deny message asynchronously to avoid blocking
|
||||
go func() {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||
reply, err := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to create MESSAGE_DENY",
|
||||
logger.Field{Key: "queue", Value: queue},
|
||||
logger.Field{Key: "task_id", Value: taskID},
|
||||
logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -820,7 +844,10 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
|
||||
|
||||
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
||||
msg, err := codec.NewMessage(cmd, nil, c.queue, headers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating operation message: %v", err)
|
||||
}
|
||||
return c.send(ctx, c.conn, msg)
|
||||
}
|
||||
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/oarkflow/json"
|
||||
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
|
||||
"github.com/oarkflow/jet"
|
||||
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
|
||||
func main() {
|
||||
flow := dag.NewDAG("Email Notification System", "email-notification", func(taskID string, result mq.Result) {
|
||||
fmt.Printf("Email notification workflow completed for task %s: %s\n", taskID, string(result.Payload))
|
||||
fmt.Printf("Email notification workflow completed for task %s: %s\n", taskID, string(utils.RemoveRecursiveFromJSON(result.Payload, "html_content")))
|
||||
})
|
||||
|
||||
// Add workflow nodes
|
||||
|
@@ -1,557 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"github.com/oarkflow/mq/logger"
|
||||
)
|
||||
|
||||
// ExampleProcessor demonstrates a custom processor with debugging
|
||||
type ExampleProcessor struct {
|
||||
name string
|
||||
tags []string
|
||||
}
|
||||
|
||||
func NewExampleProcessor(name string) *ExampleProcessor {
|
||||
return &ExampleProcessor{
|
||||
name: name,
|
||||
tags: []string{"example", "demo"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ExampleProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
|
||||
// Simulate processing time
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Add some example processing logic
|
||||
var data map[string]interface{}
|
||||
if err := task.UnmarshalPayload(&data); err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
|
||||
// Process the data
|
||||
data["processed_by"] = p.name
|
||||
data["processed_at"] = time.Now()
|
||||
|
||||
payload, _ := task.MarshalPayload(data)
|
||||
return mq.Result{Payload: payload}
|
||||
}
|
||||
|
||||
func (p *ExampleProcessor) SetConfig(payload dag.Payload) {}
|
||||
func (p *ExampleProcessor) SetTags(tags ...string) { p.tags = append(p.tags, tags...) }
|
||||
func (p *ExampleProcessor) GetTags() []string { return p.tags }
|
||||
func (p *ExampleProcessor) Consume(ctx context.Context) error { return nil }
|
||||
func (p *ExampleProcessor) Pause(ctx context.Context) error { return nil }
|
||||
func (p *ExampleProcessor) Resume(ctx context.Context) error { return nil }
|
||||
func (p *ExampleProcessor) Stop(ctx context.Context) error { return nil }
|
||||
func (p *ExampleProcessor) Close() error { return nil }
|
||||
func (p *ExampleProcessor) GetType() string { return "example" }
|
||||
func (p *ExampleProcessor) GetKey() string { return p.name }
|
||||
func (p *ExampleProcessor) SetKey(key string) { p.name = key }
|
||||
|
||||
// CustomActivityHook demonstrates custom activity processing
|
||||
type CustomActivityHook struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (h *CustomActivityHook) OnActivity(entry dag.ActivityEntry) error {
|
||||
// Custom processing of activity entries
|
||||
if entry.Level == dag.ActivityLevelError {
|
||||
h.logger.Error("Critical activity detected",
|
||||
logger.Field{Key: "activity_id", Value: entry.ID},
|
||||
logger.Field{Key: "dag_name", Value: entry.DAGName},
|
||||
logger.Field{Key: "message", Value: entry.Message},
|
||||
)
|
||||
|
||||
// Here you could send notifications, trigger alerts, etc.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CustomAlertHandler demonstrates custom alert handling
|
||||
type CustomAlertHandler struct {
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (h *CustomAlertHandler) HandleAlert(alert dag.Alert) error {
|
||||
h.logger.Warn("DAG Alert received",
|
||||
logger.Field{Key: "type", Value: alert.Type},
|
||||
logger.Field{Key: "severity", Value: alert.Severity},
|
||||
logger.Field{Key: "message", Value: alert.Message},
|
||||
)
|
||||
|
||||
// Here you could integrate with external alerting systems
|
||||
// like Slack, PagerDuty, email, etc.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Initialize logger
|
||||
log := logger.New(logger.Config{
|
||||
Level: logger.LevelInfo,
|
||||
Format: logger.FormatJSON,
|
||||
})
|
||||
|
||||
// Create a comprehensive DAG with all enhanced features
|
||||
server := mq.NewServer("demo", ":0", log)
|
||||
|
||||
// Create DAG with comprehensive configuration
|
||||
dagInstance := dag.NewDAG("production-workflow", "workflow-key", func(ctx context.Context, result mq.Result) {
|
||||
log.Info("Workflow completed",
|
||||
logger.Field{Key: "result", Value: string(result.Payload)},
|
||||
)
|
||||
})
|
||||
|
||||
// Initialize all enhanced components
|
||||
setupEnhancedDAG(dagInstance, log)
|
||||
|
||||
// Build the workflow
|
||||
buildWorkflow(dagInstance, log)
|
||||
|
||||
// Start the server and DAG
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := server.Start(ctx); err != nil {
|
||||
log.Error("Server failed to start", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Start enhanced DAG features
|
||||
startEnhancedFeatures(ctx, dagInstance, log)
|
||||
|
||||
// Set up HTTP API for monitoring and management
|
||||
setupHTTPAPI(dagInstance, log)
|
||||
|
||||
// Start the HTTP server
|
||||
go func() {
|
||||
log.Info("Starting HTTP server on :8080")
|
||||
if err := http.ListenAndServe(":8080", nil); err != nil {
|
||||
log.Error("HTTP server failed", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}()
|
||||
|
||||
// Demonstrate the enhanced features
|
||||
demonstrateFeatures(ctx, dagInstance, log)
|
||||
|
||||
// Wait for shutdown signal
|
||||
waitForShutdown(ctx, cancel, dagInstance, server, log)
|
||||
}
|
||||
|
||||
func setupEnhancedDAG(dagInstance *dag.DAG, log logger.Logger) {
|
||||
// Initialize activity logger with memory persistence
|
||||
activityConfig := dag.DefaultActivityLoggerConfig()
|
||||
activityConfig.BufferSize = 500
|
||||
activityConfig.FlushInterval = 2 * time.Second
|
||||
|
||||
persistence := dag.NewMemoryActivityPersistence()
|
||||
dagInstance.InitializeActivityLogger(activityConfig, persistence)
|
||||
|
||||
// Add custom activity hook
|
||||
customHook := &CustomActivityHook{logger: log}
|
||||
dagInstance.AddActivityHook(customHook)
|
||||
|
||||
// Initialize monitoring with comprehensive configuration
|
||||
monitorConfig := dag.MonitoringConfig{
|
||||
MetricsInterval: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
BufferSize: 1000,
|
||||
}
|
||||
|
||||
alertThresholds := &dag.AlertThresholds{
|
||||
MaxFailureRate: 0.1, // 10%
|
||||
MaxExecutionTime: 30 * time.Second,
|
||||
MaxTasksInProgress: 100,
|
||||
MinSuccessRate: 0.9, // 90%
|
||||
MaxNodeFailures: 5,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
}
|
||||
|
||||
dagInstance.InitializeMonitoring(monitorConfig, alertThresholds)
|
||||
|
||||
// Add custom alert handler
|
||||
customAlertHandler := &CustomAlertHandler{logger: log}
|
||||
dagInstance.AddAlertHandler(customAlertHandler)
|
||||
|
||||
// Initialize configuration management
|
||||
dagInstance.InitializeConfigManager()
|
||||
|
||||
// Set up rate limiting
|
||||
dagInstance.InitializeRateLimiter()
|
||||
dagInstance.SetRateLimit("validate", 10.0, 5) // 10 req/sec, burst 5
|
||||
dagInstance.SetRateLimit("process", 20.0, 10) // 20 req/sec, burst 10
|
||||
dagInstance.SetRateLimit("finalize", 5.0, 2) // 5 req/sec, burst 2
|
||||
|
||||
// Initialize retry management
|
||||
retryConfig := &dag.RetryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 1 * time.Second,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
Jitter: true,
|
||||
RetryCondition: func(err error) bool {
|
||||
// Custom retry condition - retry on specific errors
|
||||
return err != nil && err.Error() != "permanent_failure"
|
||||
},
|
||||
}
|
||||
dagInstance.InitializeRetryManager(retryConfig)
|
||||
|
||||
// Initialize transaction management
|
||||
txConfig := dag.TransactionConfig{
|
||||
DefaultTimeout: 5 * time.Minute,
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
}
|
||||
dagInstance.InitializeTransactionManager(txConfig)
|
||||
|
||||
// Initialize cleanup management
|
||||
cleanupConfig := dag.CleanupConfig{
|
||||
Interval: 5 * time.Minute,
|
||||
TaskRetentionPeriod: 1 * time.Hour,
|
||||
ResultRetentionPeriod: 2 * time.Hour,
|
||||
MaxRetainedTasks: 1000,
|
||||
}
|
||||
dagInstance.InitializeCleanupManager(cleanupConfig)
|
||||
|
||||
// Initialize performance optimizer
|
||||
dagInstance.InitializePerformanceOptimizer()
|
||||
|
||||
// Set up webhook manager for external notifications
|
||||
httpClient := dag.NewSimpleHTTPClient(30 * time.Second)
|
||||
webhookManager := dag.NewWebhookManager(httpClient, log)
|
||||
|
||||
// Add webhook for task completion events
|
||||
webhookConfig := dag.WebhookConfig{
|
||||
URL: "https://api.example.com/dag-events", // Replace with actual endpoint
|
||||
Headers: map[string]string{"Authorization": "Bearer your-token"},
|
||||
RetryCount: 3,
|
||||
Events: []string{"task_completed", "task_failed", "dag_completed"},
|
||||
}
|
||||
webhookManager.AddWebhook("task_completed", webhookConfig)
|
||||
dagInstance.SetWebhookManager(webhookManager)
|
||||
|
||||
log.Info("Enhanced DAG features initialized successfully")
|
||||
}
|
||||
|
||||
func buildWorkflow(dagInstance *dag.DAG, log logger.Logger) {
|
||||
// Create processors for each step
|
||||
validator := NewExampleProcessor("validator")
|
||||
processor := NewExampleProcessor("processor")
|
||||
enricher := NewExampleProcessor("enricher")
|
||||
finalizer := NewExampleProcessor("finalizer")
|
||||
|
||||
// Build the workflow with retry configurations
|
||||
retryConfig := &dag.RetryConfig{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 500 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
}
|
||||
|
||||
dagInstance.
|
||||
AddNodeWithRetry(dag.Function, "Validate Input", "validate", validator, retryConfig, true).
|
||||
AddNodeWithRetry(dag.Function, "Process Data", "process", processor, retryConfig).
|
||||
AddNodeWithRetry(dag.Function, "Enrich Data", "enrich", enricher, retryConfig).
|
||||
AddNodeWithRetry(dag.Function, "Finalize", "finalize", finalizer, retryConfig).
|
||||
Connect("validate", "process").
|
||||
Connect("process", "enrich").
|
||||
Connect("enrich", "finalize")
|
||||
|
||||
// Add conditional connections
|
||||
dagInstance.AddCondition("validate", "success", "process")
|
||||
dagInstance.AddCondition("validate", "failure", "finalize") // Skip to finalize on validation failure
|
||||
|
||||
// Validate the DAG structure
|
||||
if err := dagInstance.ValidateDAG(); err != nil {
|
||||
log.Error("DAG validation failed", logger.Field{Key: "error", Value: err.Error()})
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Info("Workflow built and validated successfully")
|
||||
}
|
||||
|
||||
func startEnhancedFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) {
|
||||
// Start monitoring
|
||||
dagInstance.StartMonitoring(ctx)
|
||||
|
||||
// Start cleanup manager
|
||||
dagInstance.StartCleanup(ctx)
|
||||
|
||||
// Enable batch processing
|
||||
dagInstance.SetBatchProcessingEnabled(true)
|
||||
|
||||
log.Info("Enhanced features started")
|
||||
}
|
||||
|
||||
func setupHTTPAPI(dagInstance *dag.DAG, log logger.Logger) {
|
||||
// Set up standard DAG handlers
|
||||
dagInstance.Handlers(http.DefaultServeMux, "/dag")
|
||||
|
||||
// Set up enhanced API endpoints
|
||||
enhancedAPI := dag.NewEnhancedAPIHandler(dagInstance)
|
||||
enhancedAPI.RegisterRoutes(http.DefaultServeMux)
|
||||
|
||||
// Custom endpoints for demonstration
|
||||
http.HandleFunc("/demo/activities", func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := dag.ActivityFilter{
|
||||
Limit: 50,
|
||||
}
|
||||
|
||||
activities, err := dagInstance.GetActivities(filter)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, activities); err != nil {
|
||||
log.Error("Failed to write activities response", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
})
|
||||
|
||||
http.HandleFunc("/demo/stats", func(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := dagInstance.GetActivityLogger().(*dag.ActivityLogger).WriteJSON(w, stats); err != nil {
|
||||
log.Error("Failed to write stats response", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
})
|
||||
|
||||
log.Info("HTTP API endpoints configured")
|
||||
}
|
||||
|
||||
func demonstrateFeatures(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) {
|
||||
log.Info("Demonstrating enhanced DAG features...")
|
||||
|
||||
// 1. Process a successful task
|
||||
log.Info("Processing successful task...")
|
||||
processTask(ctx, dagInstance, map[string]interface{}{
|
||||
"id": "task-001",
|
||||
"data": "valid input data",
|
||||
"type": "success",
|
||||
}, log)
|
||||
|
||||
// 2. Process a task that will fail
|
||||
log.Info("Processing failing task...")
|
||||
processTask(ctx, dagInstance, map[string]interface{}{
|
||||
"id": "task-002",
|
||||
"data": nil, // This will cause processing issues
|
||||
"type": "failure",
|
||||
}, log)
|
||||
|
||||
// 3. Process with transaction
|
||||
log.Info("Processing with transaction...")
|
||||
processWithTransaction(ctx, dagInstance, map[string]interface{}{
|
||||
"id": "task-003",
|
||||
"data": "transaction data",
|
||||
"type": "transaction",
|
||||
}, log)
|
||||
|
||||
// 4. Demonstrate rate limiting
|
||||
log.Info("Demonstrating rate limiting...")
|
||||
demonstrateRateLimiting(ctx, dagInstance, log)
|
||||
|
||||
// 5. Show monitoring metrics
|
||||
time.Sleep(2 * time.Second) // Allow time for metrics to accumulate
|
||||
showMetrics(dagInstance, log)
|
||||
|
||||
// 6. Show activity logs
|
||||
showActivityLogs(dagInstance, log)
|
||||
}
|
||||
|
||||
func processTask(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) {
|
||||
// Add context information
|
||||
ctx = context.WithValue(ctx, "user_id", "demo-user")
|
||||
ctx = context.WithValue(ctx, "session_id", "demo-session")
|
||||
ctx = context.WithValue(ctx, "trace_id", mq.NewID())
|
||||
|
||||
result := dagInstance.Process(ctx, payload)
|
||||
if result.Error != nil {
|
||||
log.Error("Task processing failed",
|
||||
logger.Field{Key: "error", Value: result.Error.Error()},
|
||||
logger.Field{Key: "payload", Value: payload},
|
||||
)
|
||||
} else {
|
||||
log.Info("Task processed successfully",
|
||||
logger.Field{Key: "result_size", Value: len(result.Payload)},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func processWithTransaction(ctx context.Context, dagInstance *dag.DAG, payload map[string]interface{}, log logger.Logger) {
|
||||
taskID := fmt.Sprintf("tx-%s", mq.NewID())
|
||||
|
||||
// Begin transaction
|
||||
tx := dagInstance.BeginTransaction(taskID)
|
||||
if tx == nil {
|
||||
log.Error("Failed to begin transaction")
|
||||
return
|
||||
}
|
||||
|
||||
// Add transaction context
|
||||
ctx = context.WithValue(ctx, "transaction_id", tx.ID)
|
||||
ctx = context.WithValue(ctx, "task_id", taskID)
|
||||
|
||||
// Process the task
|
||||
result := dagInstance.Process(ctx, payload)
|
||||
|
||||
// Commit or rollback based on result
|
||||
if result.Error != nil {
|
||||
if err := dagInstance.RollbackTransaction(tx.ID); err != nil {
|
||||
log.Error("Failed to rollback transaction",
|
||||
logger.Field{Key: "tx_id", Value: tx.ID},
|
||||
logger.Field{Key: "error", Value: err.Error()},
|
||||
)
|
||||
} else {
|
||||
log.Info("Transaction rolled back",
|
||||
logger.Field{Key: "tx_id", Value: tx.ID},
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if err := dagInstance.CommitTransaction(tx.ID); err != nil {
|
||||
log.Error("Failed to commit transaction",
|
||||
logger.Field{Key: "tx_id", Value: tx.ID},
|
||||
logger.Field{Key: "error", Value: err.Error()},
|
||||
)
|
||||
} else {
|
||||
log.Info("Transaction committed",
|
||||
logger.Field{Key: "tx_id", Value: tx.ID},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func demonstrateRateLimiting(ctx context.Context, dagInstance *dag.DAG, log logger.Logger) {
|
||||
// Try to exceed rate limits
|
||||
for i := 0; i < 15; i++ {
|
||||
allowed := dagInstance.CheckRateLimit("validate")
|
||||
log.Info("Rate limit check",
|
||||
logger.Field{Key: "attempt", Value: i + 1},
|
||||
logger.Field{Key: "allowed", Value: allowed},
|
||||
)
|
||||
|
||||
if allowed {
|
||||
processTask(ctx, dagInstance, map[string]interface{}{
|
||||
"id": fmt.Sprintf("rate-test-%d", i),
|
||||
"data": "rate limiting test",
|
||||
}, log)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func showMetrics(dagInstance *dag.DAG, log logger.Logger) {
|
||||
metrics := dagInstance.GetMonitoringMetrics()
|
||||
if metrics != nil {
|
||||
log.Info("Current DAG Metrics",
|
||||
logger.Field{Key: "total_tasks", Value: metrics.TasksTotal},
|
||||
logger.Field{Key: "completed_tasks", Value: metrics.TasksCompleted},
|
||||
logger.Field{Key: "failed_tasks", Value: metrics.TasksFailed},
|
||||
logger.Field{Key: "tasks_in_progress", Value: metrics.TasksInProgress},
|
||||
logger.Field{Key: "avg_execution_time", Value: metrics.AverageExecutionTime.String()},
|
||||
)
|
||||
|
||||
// Show node-specific metrics
|
||||
for nodeID := range map[string]bool{"validate": true, "process": true, "enrich": true, "finalize": true} {
|
||||
if nodeStats := dagInstance.GetNodeStats(nodeID); nodeStats != nil {
|
||||
log.Info("Node Metrics",
|
||||
logger.Field{Key: "node_id", Value: nodeID},
|
||||
logger.Field{Key: "executions", Value: nodeStats.TotalExecutions},
|
||||
logger.Field{Key: "failures", Value: nodeStats.FailureCount},
|
||||
logger.Field{Key: "avg_duration", Value: nodeStats.AverageExecutionTime.String()},
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Warn("Monitoring metrics not available")
|
||||
}
|
||||
}
|
||||
|
||||
func showActivityLogs(dagInstance *dag.DAG, log logger.Logger) {
|
||||
// Get recent activities
|
||||
filter := dag.ActivityFilter{
|
||||
Limit: 10,
|
||||
SortBy: "timestamp",
|
||||
SortOrder: "desc",
|
||||
}
|
||||
|
||||
activities, err := dagInstance.GetActivities(filter)
|
||||
if err != nil {
|
||||
log.Error("Failed to get activities", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Recent Activities", logger.Field{Key: "count", Value: len(activities)})
|
||||
for _, activity := range activities {
|
||||
log.Info("Activity",
|
||||
logger.Field{Key: "id", Value: activity.ID},
|
||||
logger.Field{Key: "type", Value: string(activity.Type)},
|
||||
logger.Field{Key: "level", Value: string(activity.Level)},
|
||||
logger.Field{Key: "message", Value: activity.Message},
|
||||
logger.Field{Key: "task_id", Value: activity.TaskID},
|
||||
logger.Field{Key: "node_id", Value: activity.NodeID},
|
||||
)
|
||||
}
|
||||
|
||||
// Get activity statistics
|
||||
stats, err := dagInstance.GetActivityStats(dag.ActivityFilter{})
|
||||
if err != nil {
|
||||
log.Error("Failed to get activity stats", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Activity Statistics",
|
||||
logger.Field{Key: "total_activities", Value: stats.TotalActivities},
|
||||
logger.Field{Key: "success_rate", Value: fmt.Sprintf("%.2f%%", stats.SuccessRate*100)},
|
||||
logger.Field{Key: "failure_rate", Value: fmt.Sprintf("%.2f%%", stats.FailureRate*100)},
|
||||
logger.Field{Key: "avg_duration", Value: stats.AverageDuration.String()},
|
||||
)
|
||||
}
|
||||
|
||||
func waitForShutdown(ctx context.Context, cancel context.CancelFunc, dagInstance *dag.DAG, server *mq.Server, log logger.Logger) {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
log.Info("DAG system is running. Available endpoints:",
|
||||
logger.Field{Key: "workflow", Value: "http://localhost:8080/dag/"},
|
||||
logger.Field{Key: "process", Value: "http://localhost:8080/dag/process"},
|
||||
logger.Field{Key: "metrics", Value: "http://localhost:8080/api/dag/metrics"},
|
||||
logger.Field{Key: "health", Value: "http://localhost:8080/api/dag/health"},
|
||||
logger.Field{Key: "activities", Value: "http://localhost:8080/demo/activities"},
|
||||
logger.Field{Key: "stats", Value: "http://localhost:8080/demo/stats"},
|
||||
)
|
||||
|
||||
<-sigChan
|
||||
log.Info("Shutdown signal received, cleaning up...")
|
||||
|
||||
// Graceful shutdown
|
||||
cancel()
|
||||
|
||||
// Stop enhanced features
|
||||
dagInstance.StopEnhanced(ctx)
|
||||
|
||||
// Stop server
|
||||
if err := server.Stop(ctx); err != nil {
|
||||
log.Error("Error stopping server", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
|
||||
log.Info("Shutdown complete")
|
||||
}
|
@@ -13,7 +13,7 @@ func main() {
|
||||
Payload: payload,
|
||||
}
|
||||
publisher := mq.NewPublisher("publish-1", mq.WithBrokerURL(":8081"))
|
||||
for i := 0; i < 10000000; i++ {
|
||||
for i := 0; i < 2; i++ {
|
||||
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
||||
err := publisher.Publish(context.Background(), task, "queue1")
|
||||
if err != nil {
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/oarkflow/json"
|
||||
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
|
||||
"github.com/oarkflow/jet"
|
||||
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
|
||||
func main() {
|
||||
flow := dag.NewDAG("SMS Sender", "sms-sender", func(taskID string, result mq.Result) {
|
||||
fmt.Printf("SMS workflow completed for task %s: %s\n", taskID, string(result.Payload))
|
||||
fmt.Printf("SMS workflow completed for task %s: %s\n", taskID, string(utils.RemoveRecursiveFromJSON(result.Payload, "html_content")))
|
||||
})
|
||||
|
||||
// Add SMS workflow nodes
|
||||
|
6
go.mod
6
go.mod
@@ -15,6 +15,8 @@ require (
|
||||
github.com/oarkflow/log v1.0.79
|
||||
github.com/oarkflow/xid v1.2.5
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
|
||||
golang.org/x/time v0.11.0
|
||||
)
|
||||
@@ -23,13 +25,16 @@ require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/goccy/go-reflect v1.2.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.63.0 // indirect
|
||||
github.com/prometheus/procfs v0.16.0 // indirect
|
||||
@@ -38,4 +43,5 @@ require (
|
||||
github.com/valyala/fasthttp v1.59.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
12
go.sum
12
go.sum
@@ -4,6 +4,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/goccy/go-reflect v1.2.0 h1:O0T8rZCuNmGXewnATuKYnkL0xm6o8UNOJZd/gOkb9ms=
|
||||
@@ -18,6 +19,10 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
@@ -59,6 +64,8 @@ github.com/prometheus/procfs v0.16.0/go.mod h1:8veyXUu3nGP7oaCxhX6yeaM5u4stL2FeM
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
@@ -67,6 +74,8 @@ github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDp
|
||||
github.com/valyala/fasthttp v1.59.0/go.mod h1:GTxNb9Bc6r2a9D0TWNSPwDz78UxnTGBViY3xZNEqyYU=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -76,5 +85,8 @@ golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
31
mq.go
31
mq.go
@@ -688,7 +688,10 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
|
||||
msg, err := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create PUBLISH message: %w", err)
|
||||
}
|
||||
b.broadcastToConsumers(msg)
|
||||
return nil
|
||||
}
|
||||
@@ -698,7 +701,11 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||||
|
||||
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||
ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||
if err != nil {
|
||||
log.Printf("Error creating PUBLISH_ACK message: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := b.send(ctx, conn, ack); err != nil {
|
||||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||
}
|
||||
@@ -713,7 +720,11 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
||||
|
||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
|
||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||
ack, err := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||
if err != nil {
|
||||
log.Printf("Error creating SUBSCRIBE_ACK message: %v\n", err)
|
||||
return
|
||||
}
|
||||
if err := b.send(ctx, conn, ack); err != nil {
|
||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||
}
|
||||
@@ -894,8 +905,12 @@ func (b *Broker) handleConsumer(
|
||||
fn := func(queue *Queue) {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
ack := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||
err := b.send(ctx, con.conn, ack)
|
||||
ack, err := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||
if err != nil {
|
||||
log.Printf("Error creating message for consumer %s: %v", consumerID, err)
|
||||
return
|
||||
}
|
||||
err = b.send(ctx, con.conn, ack)
|
||||
if err == nil {
|
||||
con.state = state
|
||||
}
|
||||
@@ -921,7 +936,11 @@ func (b *Broker) UpdateConsumer(ctx context.Context, consumerID string, config D
|
||||
fn := func(queue *Queue) error {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
ack := codec.NewMessage(consts.CONSUMER_UPDATE, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||
ack, err := codec.NewMessage(consts.CONSUMER_UPDATE, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||
if err != nil {
|
||||
log.Printf("Error creating message for consumer %s: %v", consumerID, err)
|
||||
return err
|
||||
}
|
||||
return b.send(ctx, con.conn, ack)
|
||||
}
|
||||
return nil
|
||||
|
6
pool.go
6
pool.go
@@ -478,6 +478,7 @@ func (wp *Pool) processNextBatch() {
|
||||
if len(tasks) > 0 {
|
||||
for _, task := range tasks {
|
||||
if task != nil && !wp.gracefulShutdown {
|
||||
wp.taskCompletionNotifier.Add(1)
|
||||
wp.handleTask(task)
|
||||
}
|
||||
}
|
||||
@@ -487,6 +488,8 @@ func (wp *Pool) processNextBatch() {
|
||||
func (wp *Pool) handleTask(task *QueueTask) {
|
||||
if task == nil || task.payload == nil {
|
||||
wp.logger.Warn().Msg("Received nil task or payload")
|
||||
// Only call Done if Add was called (which is now only for actual tasks)
|
||||
wp.taskCompletionNotifier.Done()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -803,9 +806,6 @@ func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) er
|
||||
wp.taskAvailableCond.Signal()
|
||||
wp.taskAvailableCond.L.Unlock()
|
||||
|
||||
// Track pending task
|
||||
wp.taskCompletionNotifier.Add(1)
|
||||
|
||||
// Update metrics
|
||||
atomic.AddInt64(&wp.metrics.TotalScheduled, 1)
|
||||
|
||||
|
13
publisher.go
13
publisher.go
@@ -7,11 +7,11 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
||||
"github.com/oarkflow/json"
|
||||
|
||||
|
||||
"github.com/oarkflow/json/jsonparser"
|
||||
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
@@ -111,11 +111,14 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(command, payload, queue, headers)
|
||||
msg, err := codec.NewMessage(command, payload, queue, headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
return p.waitForAck(ctx, conn)
|
||||
}
|
||||
|
||||
|
45
utils/json.go
Normal file
45
utils/json.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/oarkflow/json"
|
||||
"github.com/oarkflow/json/jsonparser"
|
||||
)
|
||||
|
||||
func RemoveFromJSONBye(jsonStr json.RawMessage, key ...string) json.RawMessage {
|
||||
return jsonparser.Delete(jsonStr, key...)
|
||||
}
|
||||
|
||||
func RemoveRecursiveFromJSON(jsonStr json.RawMessage, key ...string) json.RawMessage {
|
||||
var data any
|
||||
if err := json.Unmarshal(jsonStr, &data); err != nil {
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
for _, k := range key {
|
||||
data = removeKeyRecursive(data, k)
|
||||
}
|
||||
|
||||
result, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return jsonStr
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func removeKeyRecursive(data any, key string) any {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
delete(v, key)
|
||||
for k, val := range v {
|
||||
v[k] = removeKeyRecursive(val, key)
|
||||
}
|
||||
return v
|
||||
case []any:
|
||||
for i, item := range v {
|
||||
v[i] = removeKeyRecursive(item, key)
|
||||
}
|
||||
return v
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user