mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 16:06:55 +08:00
564 lines
14 KiB
Go
564 lines
14 KiB
Go
package dag
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/oarkflow/mq"
|
|
"github.com/oarkflow/mq/logger"
|
|
)
|
|
|
|
// BatchProcessor handles batch processing of tasks
|
|
type BatchProcessor struct {
|
|
dag *DAG
|
|
batchSize int
|
|
batchTimeout time.Duration
|
|
buffer []*mq.Task
|
|
bufferMu sync.Mutex
|
|
flushTimer *time.Timer
|
|
logger logger.Logger
|
|
processFunc func([]*mq.Task) error
|
|
stopCh chan struct{}
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// NewBatchProcessor creates a new batch processor
|
|
func NewBatchProcessor(dag *DAG, batchSize int, batchTimeout time.Duration, logger logger.Logger) *BatchProcessor {
|
|
return &BatchProcessor{
|
|
dag: dag,
|
|
batchSize: batchSize,
|
|
batchTimeout: batchTimeout,
|
|
buffer: make([]*mq.Task, 0, batchSize),
|
|
logger: logger,
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// SetProcessFunc sets the function to process batches
|
|
func (bp *BatchProcessor) SetProcessFunc(fn func([]*mq.Task) error) {
|
|
bp.processFunc = fn
|
|
}
|
|
|
|
// AddTask adds a task to the batch
|
|
func (bp *BatchProcessor) AddTask(task *mq.Task) error {
|
|
bp.bufferMu.Lock()
|
|
defer bp.bufferMu.Unlock()
|
|
|
|
bp.buffer = append(bp.buffer, task)
|
|
|
|
// Reset timer
|
|
if bp.flushTimer != nil {
|
|
bp.flushTimer.Stop()
|
|
}
|
|
bp.flushTimer = time.AfterFunc(bp.batchTimeout, bp.flushBatch)
|
|
|
|
// Check if batch is full
|
|
if len(bp.buffer) >= bp.batchSize {
|
|
bp.flushTimer.Stop()
|
|
go bp.flushBatch()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// flushBatch processes the current batch
|
|
func (bp *BatchProcessor) flushBatch() {
|
|
bp.bufferMu.Lock()
|
|
if len(bp.buffer) == 0 {
|
|
bp.bufferMu.Unlock()
|
|
return
|
|
}
|
|
|
|
tasks := make([]*mq.Task, len(bp.buffer))
|
|
copy(tasks, bp.buffer)
|
|
bp.buffer = bp.buffer[:0] // Clear buffer
|
|
bp.bufferMu.Unlock()
|
|
|
|
if bp.processFunc != nil {
|
|
if err := bp.processFunc(tasks); err != nil {
|
|
bp.logger.Error("Batch processing failed",
|
|
logger.Field{Key: "error", Value: err.Error()},
|
|
logger.Field{Key: "batch_size", Value: len(tasks)},
|
|
)
|
|
} else {
|
|
bp.logger.Info("Batch processed successfully",
|
|
logger.Field{Key: "batch_size", Value: len(tasks)},
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop stops the batch processor
|
|
func (bp *BatchProcessor) Stop() {
|
|
close(bp.stopCh)
|
|
bp.wg.Wait()
|
|
|
|
// Flush remaining tasks
|
|
bp.flushBatch()
|
|
}
|
|
|
|
// TransactionManager handles transaction-like operations for DAG execution
|
|
type TransactionManager struct {
|
|
dag *DAG
|
|
transactions map[string]*Transaction
|
|
savePoints map[string][]SavePoint
|
|
mu sync.RWMutex
|
|
logger logger.Logger
|
|
}
|
|
|
|
// Transaction represents a transactional DAG execution
|
|
type Transaction struct {
|
|
ID string `json:"id"`
|
|
TaskID string `json:"task_id"`
|
|
Status TransactionStatus `json:"status"`
|
|
StartTime time.Time `json:"start_time"`
|
|
EndTime time.Time `json:"end_time,omitempty"`
|
|
Operations []TransactionOperation `json:"operations"`
|
|
SavePoints []SavePoint `json:"save_points"`
|
|
Metadata map[string]any `json:"metadata,omitempty"`
|
|
}
|
|
|
|
// TransactionStatus represents the status of a transaction
|
|
type TransactionStatus string
|
|
|
|
const (
|
|
TransactionStatusStarted TransactionStatus = "started"
|
|
TransactionStatusCommitted TransactionStatus = "committed"
|
|
TransactionStatusRolledBack TransactionStatus = "rolled_back"
|
|
TransactionStatusFailed TransactionStatus = "failed"
|
|
)
|
|
|
|
// TransactionOperation represents an operation within a transaction
|
|
type TransactionOperation struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
NodeID string `json:"node_id"`
|
|
Data map[string]any `json:"data"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
RollbackHandler RollbackHandler `json:"-"`
|
|
}
|
|
|
|
// SavePoint represents a save point in a transaction
|
|
type SavePoint struct {
|
|
ID string `json:"id"`
|
|
Name string `json:"name"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
State map[string]any `json:"state"`
|
|
}
|
|
|
|
// RollbackHandler defines how to rollback operations
|
|
type RollbackHandler interface {
|
|
Rollback(operation TransactionOperation) error
|
|
}
|
|
|
|
// NewTransactionManager creates a new transaction manager
|
|
func NewTransactionManager(dag *DAG, logger logger.Logger) *TransactionManager {
|
|
return &TransactionManager{
|
|
dag: dag,
|
|
transactions: make(map[string]*Transaction),
|
|
savePoints: make(map[string][]SavePoint),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// BeginTransaction starts a new transaction
|
|
func (tm *TransactionManager) BeginTransaction(taskID string) *Transaction {
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
|
|
tx := &Transaction{
|
|
ID: mq.NewID(),
|
|
TaskID: taskID,
|
|
Status: TransactionStatusStarted,
|
|
StartTime: time.Now(),
|
|
Operations: make([]TransactionOperation, 0),
|
|
SavePoints: make([]SavePoint, 0),
|
|
Metadata: make(map[string]any),
|
|
}
|
|
|
|
tm.transactions[tx.ID] = tx
|
|
|
|
tm.logger.Info("Transaction started",
|
|
logger.Field{Key: "transaction_id", Value: tx.ID},
|
|
logger.Field{Key: "task_id", Value: taskID},
|
|
)
|
|
|
|
return tx
|
|
}
|
|
|
|
// AddOperation adds an operation to a transaction
|
|
func (tm *TransactionManager) AddOperation(txID string, operation TransactionOperation) error {
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
|
|
tx, exists := tm.transactions[txID]
|
|
if !exists {
|
|
return fmt.Errorf("transaction %s not found", txID)
|
|
}
|
|
|
|
if tx.Status != TransactionStatusStarted {
|
|
return fmt.Errorf("transaction %s is not active", txID)
|
|
}
|
|
|
|
operation.ID = mq.NewID()
|
|
operation.Timestamp = time.Now()
|
|
tx.Operations = append(tx.Operations, operation)
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddSavePoint adds a save point to the transaction
|
|
func (tm *TransactionManager) AddSavePoint(txID, name string, state map[string]any) error {
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
|
|
tx, exists := tm.transactions[txID]
|
|
if !exists {
|
|
return fmt.Errorf("transaction %s not found", txID)
|
|
}
|
|
|
|
savePoint := SavePoint{
|
|
ID: mq.NewID(),
|
|
Name: name,
|
|
Timestamp: time.Now(),
|
|
State: state,
|
|
}
|
|
|
|
tx.SavePoints = append(tx.SavePoints, savePoint)
|
|
tm.savePoints[txID] = tx.SavePoints
|
|
|
|
tm.logger.Info("Save point created",
|
|
logger.Field{Key: "transaction_id", Value: txID},
|
|
logger.Field{Key: "save_point_name", Value: name},
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// CommitTransaction commits a transaction
|
|
func (tm *TransactionManager) CommitTransaction(txID string) error {
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
|
|
tx, exists := tm.transactions[txID]
|
|
if !exists {
|
|
return fmt.Errorf("transaction %s not found", txID)
|
|
}
|
|
|
|
if tx.Status != TransactionStatusStarted {
|
|
return fmt.Errorf("transaction %s is not active", txID)
|
|
}
|
|
|
|
tx.Status = TransactionStatusCommitted
|
|
tx.EndTime = time.Now()
|
|
|
|
if tm.dag.debug {
|
|
tm.logger.Info("Transaction committed",
|
|
logger.Field{Key: "transaction_id", Value: txID},
|
|
logger.Field{Key: "operations_count", Value: len(tx.Operations)},
|
|
)
|
|
}
|
|
|
|
// Clean up save points
|
|
delete(tm.savePoints, txID)
|
|
|
|
return nil
|
|
}
|
|
|
|
// RollbackTransaction rolls back a transaction
|
|
func (tm *TransactionManager) RollbackTransaction(txID string) error {
|
|
tm.mu.Lock()
|
|
defer tm.mu.Unlock()
|
|
|
|
tx, exists := tm.transactions[txID]
|
|
if !exists {
|
|
return fmt.Errorf("transaction %s not found", txID)
|
|
}
|
|
|
|
if tx.Status != TransactionStatusStarted {
|
|
return fmt.Errorf("transaction %s is not active", txID)
|
|
}
|
|
|
|
// Rollback operations in reverse order
|
|
for i := len(tx.Operations) - 1; i >= 0; i-- {
|
|
operation := tx.Operations[i]
|
|
if operation.RollbackHandler != nil {
|
|
if err := operation.RollbackHandler.Rollback(operation); err != nil {
|
|
tm.logger.Error("Failed to rollback operation",
|
|
logger.Field{Key: "transaction_id", Value: txID},
|
|
logger.Field{Key: "operation_id", Value: operation.ID},
|
|
logger.Field{Key: "error", Value: err.Error()},
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
tx.Status = TransactionStatusRolledBack
|
|
tx.EndTime = time.Now()
|
|
|
|
tm.logger.Info("Transaction rolled back",
|
|
logger.Field{Key: "transaction_id", Value: txID},
|
|
logger.Field{Key: "operations_count", Value: len(tx.Operations)},
|
|
)
|
|
|
|
// Clean up save points
|
|
delete(tm.savePoints, txID)
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetTransaction retrieves a transaction by ID
|
|
func (tm *TransactionManager) GetTransaction(txID string) (*Transaction, error) {
|
|
tm.mu.RLock()
|
|
defer tm.mu.RUnlock()
|
|
|
|
tx, exists := tm.transactions[txID]
|
|
if !exists {
|
|
return nil, fmt.Errorf("transaction %s not found", txID)
|
|
}
|
|
|
|
// Return a copy
|
|
txCopy := *tx
|
|
return &txCopy, nil
|
|
}
|
|
|
|
// CleanupManager handles cleanup of completed tasks and resources
|
|
type CleanupManager struct {
|
|
dag *DAG
|
|
cleanupInterval time.Duration
|
|
retentionPeriod time.Duration
|
|
maxEntries int
|
|
logger logger.Logger
|
|
stopCh chan struct{}
|
|
running bool
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewCleanupManager creates a new cleanup manager
|
|
func NewCleanupManager(dag *DAG, cleanupInterval, retentionPeriod time.Duration, maxEntries int, logger logger.Logger) *CleanupManager {
|
|
return &CleanupManager{
|
|
dag: dag,
|
|
cleanupInterval: cleanupInterval,
|
|
retentionPeriod: retentionPeriod,
|
|
maxEntries: maxEntries,
|
|
logger: logger,
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Start begins the cleanup routine
|
|
func (cm *CleanupManager) Start(ctx context.Context) {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
|
|
if cm.running {
|
|
return
|
|
}
|
|
|
|
cm.running = true
|
|
go cm.cleanupRoutine(ctx)
|
|
|
|
cm.logger.Info("Cleanup manager started")
|
|
}
|
|
|
|
// Stop stops the cleanup routine
|
|
func (cm *CleanupManager) Stop() {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
|
|
if !cm.running {
|
|
return
|
|
}
|
|
|
|
cm.running = false
|
|
close(cm.stopCh)
|
|
|
|
cm.logger.Info("Cleanup manager stopped")
|
|
}
|
|
|
|
// cleanupRoutine performs periodic cleanup
|
|
func (cm *CleanupManager) cleanupRoutine(ctx context.Context) {
|
|
ticker := time.NewTicker(cm.cleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-cm.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
cm.performCleanup()
|
|
}
|
|
}
|
|
}
|
|
|
|
// performCleanup cleans up old tasks and resources
|
|
func (cm *CleanupManager) performCleanup() {
|
|
cutoff := time.Now().Add(-cm.retentionPeriod)
|
|
|
|
// Clean up old task managers
|
|
var toDelete []string
|
|
cm.dag.taskManager.ForEach(func(taskID string, tm *TaskManager) bool {
|
|
if tm.createdAt.Before(cutoff) {
|
|
toDelete = append(toDelete, taskID)
|
|
}
|
|
return true
|
|
})
|
|
|
|
for _, taskID := range toDelete {
|
|
if tm, exists := cm.dag.taskManager.Get(taskID); exists {
|
|
tm.Stop()
|
|
cm.dag.taskManager.Del(taskID)
|
|
}
|
|
}
|
|
|
|
// Clean up circuit breakers for removed nodes
|
|
cm.dag.circuitBreakersMu.Lock()
|
|
for nodeID := range cm.dag.circuitBreakers {
|
|
if _, exists := cm.dag.nodes.Get(nodeID); !exists {
|
|
delete(cm.dag.circuitBreakers, nodeID)
|
|
}
|
|
}
|
|
cm.dag.circuitBreakersMu.Unlock()
|
|
|
|
if len(toDelete) > 0 {
|
|
cm.logger.Info("Cleanup completed",
|
|
logger.Field{Key: "cleaned_tasks", Value: len(toDelete)},
|
|
)
|
|
}
|
|
}
|
|
|
|
// WebhookManager handles webhook notifications
|
|
type WebhookManager struct {
|
|
webhooks map[string][]WebhookConfig
|
|
httpClient HTTPClient
|
|
logger logger.Logger
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// WebhookConfig defines webhook configuration
|
|
type WebhookConfig struct {
|
|
URL string `json:"url"`
|
|
Headers map[string]string `json:"headers"`
|
|
Method string `json:"method"`
|
|
RetryCount int `json:"retry_count"`
|
|
Timeout time.Duration `json:"timeout"`
|
|
Events []string `json:"events"`
|
|
}
|
|
|
|
// HTTPClient interface for HTTP requests
|
|
type HTTPClient interface {
|
|
Post(url string, contentType string, body []byte, headers map[string]string) error
|
|
}
|
|
|
|
// WebhookEvent represents an event to send via webhook
|
|
type WebhookEvent struct {
|
|
Type string `json:"type"`
|
|
TaskID string `json:"task_id"`
|
|
NodeID string `json:"node_id,omitempty"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
Data map[string]any `json:"data"`
|
|
}
|
|
|
|
// NewWebhookManager creates a new webhook manager
|
|
func NewWebhookManager(httpClient HTTPClient, logger logger.Logger) *WebhookManager {
|
|
return &WebhookManager{
|
|
webhooks: make(map[string][]WebhookConfig),
|
|
httpClient: httpClient,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// AddWebhook adds a webhook configuration
|
|
func (wm *WebhookManager) AddWebhook(event string, config WebhookConfig) {
|
|
wm.mu.Lock()
|
|
defer wm.mu.Unlock()
|
|
|
|
if config.Method == "" {
|
|
config.Method = "POST"
|
|
}
|
|
if config.RetryCount == 0 {
|
|
config.RetryCount = 3
|
|
}
|
|
if config.Timeout == 0 {
|
|
config.Timeout = 30 * time.Second
|
|
}
|
|
|
|
wm.webhooks[event] = append(wm.webhooks[event], config)
|
|
|
|
wm.logger.Info("Webhook added",
|
|
logger.Field{Key: "event", Value: event},
|
|
logger.Field{Key: "url", Value: config.URL},
|
|
)
|
|
}
|
|
|
|
// TriggerWebhook sends webhook notifications for an event
|
|
func (wm *WebhookManager) TriggerWebhook(event WebhookEvent) {
|
|
wm.mu.RLock()
|
|
configs, exists := wm.webhooks[event.Type]
|
|
wm.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
for _, config := range configs {
|
|
// Check if this webhook should handle this event
|
|
if len(config.Events) > 0 {
|
|
found := false
|
|
for _, eventType := range config.Events {
|
|
if eventType == event.Type {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
continue
|
|
}
|
|
}
|
|
|
|
go wm.sendWebhook(config, event)
|
|
}
|
|
}
|
|
|
|
// sendWebhook sends a single webhook with retry logic
|
|
func (wm *WebhookManager) sendWebhook(config WebhookConfig, event WebhookEvent) {
|
|
payload, err := json.Marshal(event)
|
|
if err != nil {
|
|
wm.logger.Error("Failed to marshal webhook payload",
|
|
logger.Field{Key: "error", Value: err.Error()},
|
|
)
|
|
return
|
|
}
|
|
|
|
for attempt := 0; attempt < config.RetryCount; attempt++ {
|
|
err := wm.httpClient.Post(config.URL, "application/json", payload, config.Headers)
|
|
if err == nil {
|
|
wm.logger.Info("Webhook sent successfully",
|
|
logger.Field{Key: "url", Value: config.URL},
|
|
logger.Field{Key: "event_type", Value: event.Type},
|
|
)
|
|
return
|
|
}
|
|
|
|
wm.logger.Warn("Webhook delivery failed",
|
|
logger.Field{Key: "url", Value: config.URL},
|
|
logger.Field{Key: "attempt", Value: attempt + 1},
|
|
logger.Field{Key: "error", Value: err.Error()},
|
|
)
|
|
|
|
if attempt < config.RetryCount-1 {
|
|
time.Sleep(time.Duration(attempt+1) * time.Second)
|
|
}
|
|
}
|
|
|
|
wm.logger.Error("Webhook delivery failed after all retries",
|
|
logger.Field{Key: "url", Value: config.URL},
|
|
logger.Field{Key: "event_type", Value: event.Type},
|
|
)
|
|
}
|