diff --git a/dag/v2/dag.go b/dag/v2/dag.go index caa7c29..865f70b 100644 --- a/dag/v2/dag.go +++ b/dag/v2/dag.go @@ -4,10 +4,12 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "sync" "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" "github.com/oarkflow/mq/storage" "github.com/oarkflow/mq/storage/memory" ) @@ -76,10 +78,6 @@ func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG { } } -func (tm *DAG) Validate(ctx context.Context) (string, error) { - return tm.parseInitialNode(ctx) -} - func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) { val := ctx.Value("initial_node") initialNode, ok := val.(string) @@ -170,6 +168,26 @@ func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) { return predecessors, nil } +func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) Result { + var taskID string + userCtx := UserContext(ctx) + if val := userCtx.Get("task_id"); val != "" { + taskID = val + } else { + taskID = mq.NewID() + } + ctx = context.WithValue(ctx, "task_id", taskID) + resultCh := make(chan Result, 1) + manager := NewTaskManager(tm, resultCh) + tm.taskManager.Set(taskID, manager) + firstNode, err := tm.parseInitialNode(ctx) + if err != nil { + return Result{Error: err} + } + manager.ProcessTask(ctx, taskID, firstNode, payload) + return <-resultCh +} + func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { http.ServeFile(w, r, "webroot/form.html") @@ -179,7 +197,8 @@ func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) { age := r.FormValue("age") gender := r.FormValue("gender") taskID := mq.NewID() - manager := NewTaskManager(tm) + resultCh := make(chan Result, 1) + manager := NewTaskManager(tm, resultCh) tm.taskManager.Set(taskID, manager) payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender) manager.ProcessTask(r.Context(), taskID, "NodeA", json.RawMessage(payload)) @@ -208,19 +227,92 @@ func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) { func (tm *DAG) Start(addr string) { http.HandleFunc("/", func(w http.ResponseWriter, request *http.Request) { - firstNode, err := tm.Validate(request.Context()) + ctx, data, err := parse(request) if err != nil { - http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest) + http.Error(w, err.Error(), http.StatusNotFound) return } - node, _ := tm.nodes.Get(firstNode) - if node.Type == Page { - + result := tm.ProcessTask(ctx, data) + if contentType, ok := result.Ctx.Value(consts.ContentType).(string); ok && contentType == consts.TypeHtml { + w.Header().Set(consts.ContentType, consts.TypeHtml) + w.Write(result.Data) } - w.Write([]byte(firstNode)) }) http.HandleFunc("/form", tm.formHandler) http.HandleFunc("/result", tm.resultHandler) http.HandleFunc("/task-result", tm.taskStatusHandler) http.ListenAndServe(addr, nil) } + +type Context struct { + Query map[string]any +} + +func (ctx *Context) Get(key string) string { + if val, ok := ctx.Query[key]; ok { + switch val := val.(type) { + case []string: + return val[0] + case string: + return val + } + } + return "" +} + +func parse(r *http.Request) (context.Context, []byte, error) { + ctx := r.Context() + body, err := io.ReadAll(r.Body) + if err != nil { + return ctx, nil, err + } + defer r.Body.Close() + userContext := &Context{Query: make(map[string]any)} + result := make(map[string]any) + queryParams := r.URL.Query() + for key, values := range queryParams { + if len(values) > 1 { + userContext.Query[key] = values // Handle multiple values + } else { + userContext.Query[key] = values[0] // Single value + } + } + ctx = context.WithValue(ctx, "UserContext", userContext) + contentType := r.Header.Get("Content-Type") + switch { + case contentType == "application/json": + if body == nil { + return ctx, nil, nil + } + if err := json.Unmarshal(body, &result); err != nil { + return ctx, nil, err + } + + case contentType == "application/x-www-form-urlencoded": + if err := r.ParseForm(); err != nil { + return ctx, nil, err + } + result = make(map[string]any) + for key, values := range r.Form { + if len(values) > 1 { + result[key] = values + } else { + result[key] = values[0] + } + } + default: + return ctx, nil, nil + } + bt, err := json.Marshal(result) + if err != nil { + return ctx, nil, err + } + return ctx, bt, err +} + +func UserContext(ctx context.Context) *Context { + if userContext, ok := ctx.Value("UserContext").(*Context); ok { + return userContext + } + return &Context{Query: make(map[string]any)} +} diff --git a/dag/v2/task_manager.go b/dag/v2/task_manager.go index 41c72c7..c647895 100644 --- a/dag/v2/task_manager.go +++ b/dag/v2/task_manager.go @@ -28,24 +28,36 @@ type nodeResult struct { type TaskManager struct { taskStates map[string]*TaskState + currentNode string dag *DAG mu sync.RWMutex - taskQueue chan taskExecution + taskQueue chan *Task resultQueue chan nodeResult + resultCh chan Result } -type taskExecution struct { +type Task struct { ctx context.Context taskID string nodeID string payload json.RawMessage } -func NewTaskManager(dag *DAG) *TaskManager { +func NewTask(ctx context.Context, taskID, nodeID string, payload json.RawMessage) *Task { + return &Task{ + ctx: ctx, + taskID: taskID, + nodeID: nodeID, + payload: payload, + } +} + +func NewTaskManager(dag *DAG, resultCh chan Result) *TaskManager { tm := &TaskManager{ taskStates: make(map[string]*TaskState), - taskQueue: make(chan taskExecution, 100), + taskQueue: make(chan *Task, 100), resultQueue: make(chan nodeResult, 100), + resultCh: resultCh, dag: dag, } go tm.Run() @@ -57,7 +69,7 @@ func (tm *TaskManager) ProcessTask(ctx context.Context, taskID, startNode string tm.mu.Lock() tm.taskStates[startNode] = newTaskState(startNode) tm.mu.Unlock() - tm.taskQueue <- taskExecution{taskID: taskID, nodeID: startNode, payload: payload, ctx: ctx} + tm.taskQueue <- NewTask(ctx, taskID, startNode, payload) } func newTaskState(nodeID string) *TaskState { @@ -77,7 +89,7 @@ func (tm *TaskManager) Run() { }() } -func (tm *TaskManager) processNode(exec taskExecution) { +func (tm *TaskManager) processNode(exec *Task) { node, exists := tm.dag.nodes.Get(exec.nodeID) if !exists { fmt.Printf("Node %s does not exist\n", exec.nodeID) @@ -92,13 +104,20 @@ func (tm *TaskManager) processNode(exec taskExecution) { } state.Status = StatusProcessing state.UpdatedAt = time.Now() - result := node.Handler(context.Background(), exec.payload) + tm.currentNode = exec.nodeID + result := node.Handler(exec.ctx, exec.payload) state.UpdatedAt = time.Now() state.Result = result - state.Status = result.Status - if result.Status == StatusFailed { - fmt.Printf("Task %s failed at node %s: %v\n", exec.taskID, exec.nodeID, result.Error) - tm.processFinalResult(exec.taskID, state) + if result.Ctx == nil { + result.Ctx = exec.ctx + } + if result.Error != nil { + state.Status = StatusFailed + } else { + state.Status = StatusCompleted + } + if node.Type == Page { + tm.resultCh <- result return } tm.resultQueue <- nodeResult{taskID: exec.taskID, nodeID: exec.nodeID, result: result, ctx: exec.ctx} @@ -117,16 +136,7 @@ func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) { if !ok { return } - if len(node.Edges) > 0 { - for _, edge := range node.Edges { - tm.mu.Lock() - if _, exists := tm.taskStates[edge.To.ID]; !exists { - tm.taskStates[edge.To.ID] = newTaskState(edge.To.ID) - } - tm.mu.Unlock() - tm.taskQueue <- taskExecution{taskID: nodeResult.taskID, nodeID: edge.To.ID, payload: nodeResult.result.Data, ctx: nodeResult.ctx} - } - } else { + if nodeResult.result.Error != nil || len(node.Edges) == 0 { parentNodes, err := tm.dag.GetPreviousNodes(nodeResult.nodeID) if err == nil { for _, parentNode := range parentNodes { @@ -145,6 +155,15 @@ func (tm *TaskManager) onNodeCompleted(nodeResult nodeResult) { } } } + return + } + for _, edge := range node.Edges { + tm.mu.Lock() + if _, exists := tm.taskStates[edge.To.ID]; !exists { + tm.taskStates[edge.To.ID] = newTaskState(edge.To.ID) + } + tm.mu.Unlock() + tm.taskQueue <- NewTask(nodeResult.ctx, nodeResult.taskID, edge.To.ID, nodeResult.result.Data) } } diff --git a/examples/v2.go b/examples/v2.go index 7fea36c..80aa1f0 100644 --- a/examples/v2.go +++ b/examples/v2.go @@ -7,11 +7,12 @@ import ( "github.com/oarkflow/jet" + "github.com/oarkflow/mq/consts" v2 "github.com/oarkflow/mq/dag/v2" ) func Form(ctx context.Context, payload json.RawMessage) v2.Result { - template := []byte(` + template := `
@@ -21,7 +22,7 @@ func Form(ctx context.Context, payload json.RawMessage) v2.Result {