diff --git a/business/jxutils/tasksch/parallel_task.go b/business/jxutils/tasksch/parallel_task.go index 3eecc6cae..18d012c62 100644 --- a/business/jxutils/tasksch/parallel_task.go +++ b/business/jxutils/tasksch/parallel_task.go @@ -86,7 +86,7 @@ func NewParallelTask(taskName string, config *ParallelConfig, userName string, w } task := &ParallelTask{ subFinishChan: make(chan interface{}, config.ParallelCount), - taskChan: make(chan []interface{}, len(realItemList)+config.ParallelCount), // 确保能装下所有taskitem,加结束标记 + taskChan: make(chan []interface{}, len(realItemList)), resultHandler: config.ResultHandler, worker: worker, jobList: jobList, @@ -114,8 +114,8 @@ func (task *ParallelTask) Run() { goto end default: select { - case job := <-task.taskChan: - if job == nil { // 任务完成 + case job, ok := <-task.taskChan: + if !ok { // 任务完成 chanRetVal = retVal goto end } else { @@ -149,9 +149,7 @@ func (task *ParallelTask) Run() { for _, job := range task.jobList { task.taskChan <- job } - for i := 0; i < task.ParallelCount; i++ { - task.taskChan <- nil - } + close(task.taskChan) taskResult := make([]interface{}, 0) var taskErr error @@ -170,9 +168,6 @@ func (task *ParallelTask) Run() { } task.locker.Lock() - if task.Status != TaskStatusCanceling { - close(task.quitChan) - } if taskErr != nil { // 如果有错误,肯定就是失败了 task.Status = TaskStatusFailed } else { @@ -190,7 +185,6 @@ func (task *ParallelTask) Run() { globals.SugarLogger.Debugf("ParallelTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr) - close(task.finishChan) close(task.subFinishChan) if task.resultHandler != nil { diff --git a/business/jxutils/tasksch/sequence_task.go b/business/jxutils/tasksch/sequence_task.go index 15912d9ea..5e13fea24 100644 --- a/business/jxutils/tasksch/sequence_task.go +++ b/business/jxutils/tasksch/sequence_task.go @@ -57,9 +57,6 @@ func (task *SeqTask) Run() { } EndFor: task.locker.Lock() - if task.Status != TaskStatusCanceling { - close(task.quitChan) - } if taskErr != nil { // 如果有错误,肯定就是失败了 task.Status = TaskStatusFailed } else { @@ -76,8 +73,6 @@ func (task *SeqTask) Run() { task.locker.Unlock() globals.SugarLogger.Debugf("SeqTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr) - - close(task.finishChan) }) } diff --git a/business/jxutils/tasksch/task.go b/business/jxutils/tasksch/task.go index 3ad2317b2..29d0f4cff 100644 --- a/business/jxutils/tasksch/task.go +++ b/business/jxutils/tasksch/task.go @@ -101,8 +101,8 @@ func (t *BaseTask) Init(parallelCount, batchSize int, isContinueWhenError bool, t.TerminatedAt = utils.DefaultTimeValue t.TotalItemCount = totalItemCount t.TotalJobCount = totalJobCount - t.quitChan = make(chan int, 1) - t.finishChan = make(chan int, 1) + t.quitChan = make(chan int) + t.finishChan = make(chan int) t.Status = TaskStatusWorking t.C = t.finishChan @@ -228,11 +228,17 @@ func (t *BaseTask) run(taskHandler func()) { }() taskHandler() + select { + case <-t.quitChan: + default: + close(t.quitChan) + } for _, subTask := range t.Children { if _, err := subTask.GetResult(0); err != nil { globals.SugarLogger.Warnf("BaseTask run, failed with error:%v", err) } } + close(t.finishChan) }) }