diff --git a/internal/app/logic/task.go b/internal/app/logic/task.go index c50f435..de2ef74 100644 --- a/internal/app/logic/task.go +++ b/internal/app/logic/task.go @@ -13,11 +13,10 @@ import ( ) func (t *taskLogic) RunTaskById(id int) error { - v, ok := t.taskJobMap.Load(id) - if !ok { - return errors.New("don't exist task id") + task, err := t.getTaskJob(id) + if err != nil { + return errors.New("id不存在") } - task := v.(*model.TaskJob) if task.Running { return errors.New("task is running") } @@ -68,9 +67,8 @@ func (t *taskLogic) run(ctx context.Context, data *model.TaskJob) { log.Logger.Infow("任务执行成功", "target", data.Task.OperationTarget) if data.Task.NextId != nil { - v, ok := t.taskJobMap.Load(*data.Task.NextId) - nextTask := v.(*model.TaskJob) - if !ok { + nextTask, err := t.getTaskJob(*data.Task.NextId) + if err != nil { log.Logger.Errorw("无法获取到下一个节点,结束任务", "nextId", data.Task.NextId) return } diff --git a/internal/app/logic/task_logic.go b/internal/app/logic/task_logic.go index b8c4228..9761950 100644 --- a/internal/app/logic/task_logic.go +++ b/internal/app/logic/task_logic.go @@ -21,6 +21,14 @@ type taskLogic struct { var TaskLogic taskLogic +func (t *taskLogic) getTaskJob(id int) (*model.TaskJob, error) { + c, ok := t.taskJobMap.Load(id) + if !ok { + return nil, errors.New("don't exist this task id") + } + return c.(*model.TaskJob), nil +} + func (t *taskLogic) InitTaskJob() { for _, v := range repository.TaskRepository.GetAllTask() { tj := &model.TaskJob{ @@ -61,11 +69,10 @@ func (t *taskLogic) cronHandle(data *model.TaskJob) func() { } func (t *taskLogic) StopTaskJob(id int) error { - c, ok := t.taskJobMap.Load(id) - if !ok { - return errors.New("id不存在") + taskJob, err := t.getTaskJob(id) + if err != nil { + return errors.New("don't exist this task id") } - taskJob := c.(*model.TaskJob) if taskJob.Running { taskJob.Cancel() } @@ -73,23 +80,21 @@ func (t *taskLogic) StopTaskJob(id int) error { } func (t *taskLogic) StartTaskJob(id int) error { - c, ok := t.taskJobMap.Load(id) - if !ok { - return errors.New("id不存在") + taskJob, err := t.getTaskJob(id) + if err != nil { + return errors.New("don't exist this task id") } - TaskJob := c.(*model.TaskJob) - TaskJob.Cron.Run() + taskJob.Cron.Run() return nil } func (t *taskLogic) GetAllTaskJob() []model.TaskVo { result := repository.TaskRepository.GetAllTaskWithProcessName() for i, v := range result { - item, ok := t.taskJobMap.Load(v.Id) - if !ok { + task, err := t.getTaskJob(v.Id) + if err != nil { continue } - task := item.(*model.TaskJob) result[i].Id = task.Task.Id result[i].Running = task.Running result[i].Enable = task.Task.Enable @@ -133,11 +138,10 @@ func (t *taskLogic) CreateTask(data model.Task) error { } func (t *taskLogic) EditTask(data model.Task) error { - v, ok := t.taskJobMap.Load(data.Id) - if !ok { + tj, err := t.getTaskJob(data.Id) + if err != nil { return errors.New("don't exist this task id") } - tj := v.(*model.TaskJob) if tj.Running { return errors.New("can't edit when task is running") } @@ -162,11 +166,10 @@ func (t *taskLogic) EditTask(data model.Task) error { } func (t *taskLogic) EditTaskEnable(id int, status bool) error { - v, ok := t.taskJobMap.Load(id) - if !ok { + tj, err := t.getTaskJob(id) + if err != nil { return errors.New("don't exist this task id") } - tj := v.(*model.TaskJob) tj.Task.Enable = status repository.TaskRepository.EditTaskEnable(id, status) if tj.Cron != nil {