diff --git a/dag/api.go b/dag/api.go new file mode 100644 index 0000000..130e5d7 --- /dev/null +++ b/dag/api.go @@ -0,0 +1,157 @@ +package dag + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/consts" +) + +type Request struct { + Payload json.RawMessage `json:"payload"` + Interval time.Duration `json:"interval"` + Schedule bool `json:"schedule"` + Overlap bool `json:"overlap"` + Recurring bool `json:"recurring"` +} + +func (tm *DAG) Handlers() { + http.HandleFunc("POST /request", tm.Request) + http.HandleFunc("POST /publish", tm.Publish) + http.HandleFunc("POST /schedule", tm.Schedule) + http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + tm.PauseConsumer(request.Context(), id) + } + }) + http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { + id := request.PathValue("id") + if id != "" { + tm.ResumeConsumer(request.Context(), id) + } + }) + http.HandleFunc("/pause", func(w http.ResponseWriter, request *http.Request) { + err := tm.Pause(request.Context()) + if err != nil { + http.Error(w, "Failed to pause", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "paused"}) + }) + http.HandleFunc("/resume", func(w http.ResponseWriter, request *http.Request) { + err := tm.Resume(request.Context()) + if err != nil { + http.Error(w, "Failed to resume", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "resumed"}) + }) + http.HandleFunc("/stop", func(w http.ResponseWriter, request *http.Request) { + err := tm.Stop(request.Context()) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "stopped"}) + }) + http.HandleFunc("/close", func(w http.ResponseWriter, request *http.Request) { + err := tm.Close() + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + json.NewEncoder(w).Encode(map[string]string{"status": "closed"}) + }) + http.HandleFunc("/dot", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintln(w, tm.ExportDOT()) + }) + http.HandleFunc("/ui", func(w http.ResponseWriter, r *http.Request) { + image := fmt.Sprintf("%s.svg", mq.NewID()) + err := tm.SaveSVG(image) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer os.Remove(image) + svgBytes, err := os.ReadFile(image) + if err != nil { + http.Error(w, "Could not read SVG file", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "image/svg+xml") + if _, err := w.Write(svgBytes); err != nil { + http.Error(w, "Could not write SVG response", http.StatusInternalServerError) + return + } + }) +} + +func (tm *DAG) request(w http.ResponseWriter, r *http.Request, async bool) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) + return + } + var request Request + if r.Body != nil { + defer r.Body.Close() + var err error + payload, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + err = json.Unmarshal(payload, &request) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + } else { + http.Error(w, "Empty request body", http.StatusBadRequest) + return + } + ctx := r.Context() + if async { + ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) + } + var opts []mq.SchedulerOption + if request.Interval > 0 { + opts = append(opts, mq.WithInterval(request.Interval)) + } + if request.Overlap { + opts = append(opts, mq.WithOverlap()) + } + if request.Recurring { + opts = append(opts, mq.WithRecurring()) + } + var rs mq.Result + if request.Schedule { + rs = tm.ScheduleTask(ctx, request.Payload, opts...) + } else { + rs = tm.Process(ctx, request.Payload) + } + if rs.Error != nil { + http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(rs) +} + +func (tm *DAG) Request(w http.ResponseWriter, r *http.Request) { + tm.request(w, r, true) +} + +func (tm *DAG) Publish(w http.ResponseWriter, r *http.Request) { + tm.request(w, r, false) +} + +func (tm *DAG) Schedule(w http.ResponseWriter, r *http.Request) { + tm.request(w, r, false) +} diff --git a/dag/dag.go b/dag/dag.go index cac9a67..9b6c900 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -2,9 +2,7 @@ package dag import ( "context" - "encoding/json" "fmt" - "io" "log" "net/http" "sync" @@ -81,10 +79,9 @@ func (tm *DAG) GetType() string { func (tm *DAG) listenForTaskCleanup() { for taskID := range tm.taskCleanupCh { - tm.mu.Lock() - delete(tm.taskContext, taskID) - tm.mu.Unlock() - log.Printf("DAG - Task %s cleaned up", taskID) + if tm.server.Options().CleanTaskOnComplete() { + tm.taskCleanup(taskID) + } } } @@ -182,35 +179,6 @@ func (tm *DAG) GetStartNode() string { return tm.startNode } -func (tm *DAG) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) - return - } - var payload []byte - if r.Body != nil { - defer r.Body.Close() - var err error - payload, err = io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - } else { - http.Error(w, "Empty request body", http.StatusBadRequest) - return - } - ctx := r.Context() - ctx = mq.SetHeaders(ctx, map[string]string{consts.AwaitResponseKey: "true"}) - rs := tm.Process(ctx, payload) - if rs.Error != nil { - http.Error(w, fmt.Sprintf("[DAG Error] - %v", rs.Error), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(rs) -} - func (tm *DAG) Start(ctx context.Context, addr string) error { if !tm.server.SyncMode() { go func() { @@ -236,6 +204,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error { } } log.Printf("DAG - HTTP_SERVER ~> started on %s", addr) + tm.Handlers() config := tm.server.TLSConfig() if config.UseTLS { return http.ListenAndServeTLS(addr, config.CertPath, config.KeyPath, nil) diff --git a/dag/ui.go b/dag/ui.go index a791e40..118b0f3 100644 --- a/dag/ui.go +++ b/dag/ui.go @@ -73,7 +73,6 @@ func (tm *DAG) dfs(v string, visited map[string]bool, discoveryTime, finishedTim tm.dfs(adj.Key, visited, discoveryTime, finishedTime, timeVal) } } - } } tm.handleConditionalEdges(v, visited, discoveryTime, finishedTime, timeVal) @@ -129,8 +128,6 @@ func (tm *DAG) ExportDOT() string { var sb strings.Builder sb.WriteString(fmt.Sprintf(`digraph "%s" {`, tm.name)) sb.WriteString("\n") - sb.WriteString(` bgcolor="lightyellow";`) - sb.WriteString("\n") sb.WriteString(fmt.Sprintf(` label="%s";`, tm.name)) sb.WriteString("\n") sb.WriteString(` labelloc="t";`) diff --git a/examples/dag.go b/examples/dag.go index 10bcc80..f483070 100644 --- a/examples/dag.go +++ b/examples/dag.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "github.com/oarkflow/mq/examples/tasks" "github.com/oarkflow/mq/services" @@ -45,35 +44,14 @@ func sendData(f *dag.DAG) { } func Sync() { - f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse)) + f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithSyncMode(true), mq.WithNotifyResponse(tasks.NotifyResponse)) setup(f) - fmt.Println(f.ExportDOT()) sendData(f) - fmt.Println(f.SaveSVG("dag.svg")) } func aSync() { - f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithNotifyResponse(tasks.NotifyResponse)) + f := dag.NewDAG("Sample DAG", "sample-dag", mq.WithCleanTaskOnComplete(), mq.WithNotifyResponse(tasks.NotifyResponse)) setup(f) - http.HandleFunc("POST /request", f.ServeHTTP) - http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { - id := request.PathValue("id") - if id != "" { - f.PauseConsumer(request.Context(), id) - } - }) - http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) { - id := request.PathValue("id") - if id != "" { - f.ResumeConsumer(request.Context(), id) - } - }) - http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) { - f.Pause(request.Context()) - }) - http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) { - f.Resume(request.Context()) - }) err := f.Start(context.TODO(), ":8083") if err != nil { panic(err) diff --git a/options.go b/options.go index 2f47c88..76cf5c1 100644 --- a/options.go +++ b/options.go @@ -79,6 +79,7 @@ type Options struct { numOfWorkers int maxMemoryLoad int64 syncMode bool + cleanTaskOnComplete bool enableWorkerPool bool respondPendingResult bool } @@ -95,6 +96,10 @@ func (o *Options) Storage() TaskStorage { return o.storage } +func (o *Options) CleanTaskOnComplete() bool { + return o.cleanTaskOnComplete +} + func (o *Options) QueueSize() int { return o.queueSize } @@ -186,6 +191,13 @@ func WithSyncMode(mode bool) Option { } } +// WithCleanTaskOnComplete - +func WithCleanTaskOnComplete() Option { + return func(opts *Options) { + opts.cleanTaskOnComplete = true + } +} + // WithRespondPendingResult - func WithRespondPendingResult(mode bool) Option { return func(opts *Options) { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 555a68a..622dcc3 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -65,7 +65,7 @@ func (g *Map[K, V]) Size() int { // Keys returns a slice of all keys in the map func (g *Map[K, V]) Keys() []K { - keys := []K{} + var keys []K g.ForEach(func(k K, _ V) bool { keys = append(keys, k) return true @@ -75,7 +75,7 @@ func (g *Map[K, V]) Keys() []K { // Values returns a slice of all values in the map func (g *Map[K, V]) Values() []V { - values := []V{} + var values []V g.ForEach(func(_ K, v V) bool { values = append(values, v) return true