Files
mq/dag/v2/dag.go
2024-11-18 19:01:15 +05:45

319 lines
7.5 KiB
Go

package v2
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
)
type TaskStatus string
const (
StatusPending TaskStatus = "Pending"
StatusProcessing TaskStatus = "Processing"
StatusCompleted TaskStatus = "Completed"
StatusFailed TaskStatus = "Failed"
)
type Result struct {
Ctx context.Context
Data json.RawMessage
Error error
Status TaskStatus
}
type NodeType int
func (c NodeType) IsValid() bool { return c >= Process && c <= Page }
const (
Process NodeType = iota
Page
)
type Node struct {
Type NodeType
ID string
Handler func(ctx context.Context, payload json.RawMessage) Result
Edges []Edge
}
type EdgeType int
func (c EdgeType) IsValid() bool { return c >= Simple && c <= Iterator }
const (
Simple EdgeType = iota
Iterator
)
type Edge struct {
From *Node
To *Node
Type EdgeType
}
type DAG struct {
nodes storage.IMap[string, *Node]
taskManager storage.IMap[string, *TaskManager]
finalResult func(taskID string, result Result)
mu sync.Mutex
Error error
startNode string
}
func NewDAG(finalResultCallback func(taskID string, result Result)) *DAG {
return &DAG{
nodes: memory.New[string, *Node](),
taskManager: memory.New[string, *TaskManager](),
finalResult: finalResultCallback,
}
}
func (tm *DAG) parseInitialNode(ctx context.Context) (string, error) {
val := ctx.Value("initial_node")
initialNode, ok := val.(string)
if ok {
return initialNode, nil
}
if tm.startNode == "" {
firstNode := tm.findStartNode()
if firstNode != nil {
tm.startNode = firstNode.ID
}
}
if tm.startNode == "" {
return "", fmt.Errorf("initial node not found")
}
return tm.startNode, nil
}
func (tm *DAG) findStartNode() *Node {
incomingEdges := make(map[string]bool)
connectedNodes := make(map[string]bool)
for _, node := range tm.nodes.AsMap() {
for _, edge := range node.Edges {
if edge.Type.IsValid() {
connectedNodes[node.ID] = true
connectedNodes[edge.To.ID] = true
incomingEdges[edge.To.ID] = true
}
}
}
for nodeID, node := range tm.nodes.AsMap() {
if !incomingEdges[nodeID] && connectedNodes[nodeID] {
return node
}
}
return nil
}
func (tm *DAG) AddNode(nodeType NodeType, nodeID string, handler func(ctx context.Context, payload json.RawMessage) Result) *DAG {
if tm.Error != nil {
return tm
}
tm.nodes.Set(nodeID, &Node{ID: nodeID, Handler: handler, Type: nodeType})
return tm
}
func (tm *DAG) AddEdge(from string, targets ...string) *DAG {
if tm.Error != nil {
return tm
}
node, ok := tm.nodes.Get(from)
if !ok {
tm.Error = fmt.Errorf("node not found %s", from)
return tm
}
for _, target := range targets {
if targetNode, ok := tm.nodes.Get(target); ok {
edge := Edge{From: node, To: targetNode, Type: Simple}
node.Edges = append(node.Edges, edge)
}
}
return tm
}
func (tm *DAG) GetNextNodes(key string) ([]*Node, error) {
node, exists := tm.nodes.Get(key)
if !exists {
return nil, fmt.Errorf("node with key %s does not exist", key)
}
var successors []*Node
for _, edge := range node.Edges {
successors = append(successors, edge.To)
}
return successors, nil
}
func (tm *DAG) GetPreviousNodes(key string) ([]*Node, error) {
var predecessors []*Node
tm.nodes.ForEach(func(_ string, node *Node) bool {
for _, target := range node.Edges {
if target.To.ID == key {
predecessors = append(predecessors, node)
}
}
return true
})
return predecessors, nil
}
func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) Result {
var taskID string
userCtx := UserContext(ctx)
if val := userCtx.Get("task_id"); val != "" {
taskID = val
} else {
taskID = mq.NewID()
}
ctx = context.WithValue(ctx, "task_id", taskID)
resultCh := make(chan Result, 1)
manager := NewTaskManager(tm, resultCh)
tm.taskManager.Set(taskID, manager)
firstNode, err := tm.parseInitialNode(ctx)
if err != nil {
return Result{Error: err}
}
manager.ProcessTask(ctx, taskID, firstNode, payload)
return <-resultCh
}
func (tm *DAG) formHandler(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
http.ServeFile(w, r, "webroot/form.html")
} else if r.Method == "POST" {
r.ParseForm()
email := r.FormValue("email")
age := r.FormValue("age")
gender := r.FormValue("gender")
taskID := mq.NewID()
resultCh := make(chan Result, 1)
manager := NewTaskManager(tm, resultCh)
tm.taskManager.Set(taskID, manager)
payload := fmt.Sprintf(`{"email": "%s", "age": "%s", "gender": "%s"}`, email, age, gender)
manager.ProcessTask(r.Context(), taskID, "NodeA", json.RawMessage(payload))
http.Redirect(w, r, "/result?taskID="+taskID, http.StatusFound)
}
}
func (tm *DAG) resultHandler(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "webroot/result.html")
}
func (tm *DAG) taskStatusHandler(w http.ResponseWriter, r *http.Request) {
taskID := r.URL.Query().Get("taskID")
if taskID == "" {
http.Error(w, `{"message": "taskID is missing"}`, http.StatusBadRequest)
return
}
manager, ok := tm.taskManager.Get(taskID)
if !ok {
http.Error(w, `{"message": "Invalid TaskID"}`, http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(manager.taskStates)
}
func (tm *DAG) Start(addr string) {
http.HandleFunc("/", func(w http.ResponseWriter, request *http.Request) {
ctx, data, err := parse(request)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
result := tm.ProcessTask(ctx, data)
if contentType, ok := result.Ctx.Value(consts.ContentType).(string); ok && contentType == consts.TypeHtml {
w.Header().Set(consts.ContentType, consts.TypeHtml)
w.Write(result.Data)
}
})
http.HandleFunc("/form", tm.formHandler)
http.HandleFunc("/result", tm.resultHandler)
http.HandleFunc("/task-result", tm.taskStatusHandler)
http.ListenAndServe(addr, nil)
}
type Context struct {
Query map[string]any
}
func (ctx *Context) Get(key string) string {
if val, ok := ctx.Query[key]; ok {
switch val := val.(type) {
case []string:
return val[0]
case string:
return val
}
}
return ""
}
func parse(r *http.Request) (context.Context, []byte, error) {
ctx := r.Context()
body, err := io.ReadAll(r.Body)
if err != nil {
return ctx, nil, err
}
defer r.Body.Close()
userContext := &Context{Query: make(map[string]any)}
result := make(map[string]any)
queryParams := r.URL.Query()
for key, values := range queryParams {
if len(values) > 1 {
userContext.Query[key] = values // Handle multiple values
} else {
userContext.Query[key] = values[0] // Single value
}
}
ctx = context.WithValue(ctx, "UserContext", userContext)
contentType := r.Header.Get("Content-Type")
switch {
case contentType == "application/json":
if body == nil {
return ctx, nil, nil
}
if err := json.Unmarshal(body, &result); err != nil {
return ctx, nil, err
}
case contentType == "application/x-www-form-urlencoded":
if err := r.ParseForm(); err != nil {
return ctx, nil, err
}
result = make(map[string]any)
for key, values := range r.Form {
if len(values) > 1 {
result[key] = values
} else {
result[key] = values[0]
}
}
default:
return ctx, nil, nil
}
bt, err := json.Marshal(result)
if err != nil {
return ctx, nil, err
}
return ctx, bt, err
}
func UserContext(ctx context.Context) *Context {
if userContext, ok := ctx.Value("UserContext").(*Context); ok {
return userContext
}
return &Context{Query: make(map[string]any)}
}