From 66ef068fc38cdc2ed751f2de31ad05cd778e129c Mon Sep 17 00:00:00 2001 From: gazebo Date: Sat, 20 Oct 2018 09:22:36 +0800 Subject: [PATCH] - refactor tasksch - sequence task added. - task tree added. --- .../{parellel_task.go => parallel_task.go} | 81 ++++++++++--------- .../{task_test.go => parallel_task_test.go} | 0 business/jxutils/tasksch/sequence_task.go | 53 ++++++------ .../jxutils/tasksch/sequence_task_test.go | 52 ++++++++++++ business/jxutils/tasksch/task.go | 77 +++++++++++++++--- business/jxutils/tasksch/task_man.go | 11 +-- 6 files changed, 191 insertions(+), 83 deletions(-) rename business/jxutils/tasksch/{parellel_task.go => parallel_task.go} (71%) rename business/jxutils/tasksch/{task_test.go => parallel_task_test.go} (100%) create mode 100644 business/jxutils/tasksch/sequence_task_test.go diff --git a/business/jxutils/tasksch/parellel_task.go b/business/jxutils/tasksch/parallel_task.go similarity index 71% rename from business/jxutils/tasksch/parellel_task.go rename to business/jxutils/tasksch/parallel_task.go index 18b96981e..e2235bee3 100644 --- a/business/jxutils/tasksch/parellel_task.go +++ b/business/jxutils/tasksch/parallel_task.go @@ -48,6 +48,26 @@ func NewParallelConfig() *ParallelConfig { } } +func (c *ParallelConfig) SetParallelCount(parallelCount int) *ParallelConfig { + c.ParallelCount = parallelCount + return c +} + +func (c *ParallelConfig) SetBatchSize(batchSize int) *ParallelConfig { + c.BatchSize = batchSize + return c +} + +func (c *ParallelConfig) SetIsContinueWhenError(isContinueWhenError bool) *ParallelConfig { + c.IsContinueWhenError = isContinueWhenError + return c +} + +func (c *ParallelConfig) SetResultHandler(resultHandler ResultHandlerFunc) *ParallelConfig { + c.ResultHandler = resultHandler + return c +} + func NewParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *Task { if config == nil { config = NewParallelConfig() @@ -60,42 +80,25 @@ func NewParallelTask(taskName string, userName string, config *ParallelConfig, w } realItemList := utils.Interface2Slice(itemList) jobList := jxutils.SplitSlice(realItemList, config.BatchSize) - jobListLen := jxutils.GetSliceLen(jobList) + jobListLen := len(jobList) if config.ParallelCount > jobListLen { config.ParallelCount = jobListLen } - task := &Task{ - BaseTask: BaseTask{ - ParallelCount: config.ParallelCount, - BatchSize: config.BatchSize, - IsContinueWhenError: config.IsContinueWhenError, - params: params, - - ID: utils.GetUUID(), - Name: taskName, - CreatedAt: time.Now(), - CreatedBy: userName, - UpdatedAt: time.Now(), - TotalJobCount: len(jobList), - TotalItemCount: len(realItemList), - quitChan: make(chan int, config.ParallelCount), - finishChan: make(chan int, 2), - Status: TaskStatusWorking, - }, subFinishChan: make(chan interface{}, config.ParallelCount), taskChan: make(chan []interface{}, len(realItemList)+config.ParallelCount), // 确保能装下所有taskitem,加结束标记 resultHandler: config.ResultHandler, worker: worker, jobList: jobList, } - task.C = task.finishChan + task.Init(config.ParallelCount, config.BatchSize, config.IsContinueWhenError, params, taskName, userName, len(realItemList), jobListLen) return task } func RunParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *Task { task := NewParallelTask(taskName, userName, config, worker, itemList, params...) - return task.Run() + task.Run() + return task } func RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHandlerFunc, parallelCount, batchSize int, userName string, worker WorkFunc, itemList interface{}, params ...interface{}) *Task { @@ -109,9 +112,9 @@ func RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHand return task } -func (task *Task) Run() *Task { - go func() { - globals.SugarLogger.Debugf("Run ParallelTask %s", task.Name) +func (task *Task) Run() { + task.run(func() { + globals.SugarLogger.Debugf("ParallelTask.Run %s", task.Name) for i := 0; i < task.ParallelCount; i++ { go func() { var chanRetVal interface{} @@ -126,21 +129,24 @@ func (task *Task) Run() *Task { goto end } else { result, err := task.worker(job, task.params...) - globals.SugarLogger.Debugf("RunTask %s, after call worker result:%v, err:%v", task.Name, result, err) + // globals.SugarLogger.Debugf("ParallelTask.Run %s, after call worker result:%v, err:%v", task.Name, result, err) task.finishedOneJob(len(job), err) if err == nil { if result != nil { retVal = append(retVal, utils.Interface2Slice(result)...) } - } else if !task.IsContinueWhenError { // 出错 - chanRetVal = err - goto end + } else { + globals.SugarLogger.Infof("ParallelTask.Run %s, subtask(job:%s, params:%s) result:%v, failed with error:%v", task.Name, utils.Format4Output(job, true), utils.Format4Output(task.params, true), result, err) + if !task.IsContinueWhenError { // 出错 + chanRetVal = err + goto end + } } } } } end: - globals.SugarLogger.Debugf("RunTask %s, put to chann chanRetVal:%v", task.Name, chanRetVal) + // globals.SugarLogger.Debugf("ParallelTask.Run %s, put to chann chanRetVal:%v", task.Name, chanRetVal) task.locker.RLock() if task.Status < TaskStatusEndBegin { task.subFinishChan <- chanRetVal @@ -159,11 +165,11 @@ func (task *Task) Run() *Task { var taskErr error for i := 0; i < task.ParallelCount; i++ { result := <-task.subFinishChan - // globals.SugarLogger.Debugf("RunTask %s, received from chann result:%v", taskName, result) - if err2, ok := result.(error); ok { + // globals.SugarLogger.Debugf("ParallelTask.Run %s, received from chann result:%v", taskName, result) + if err, ok := result.(error); ok { task.Cancel() taskResult = nil - taskErr = err2 + taskErr = err break // 出错情况下是否需要直接跳出? } else if result != nil { resultList := result.([]interface{}) @@ -182,20 +188,19 @@ func (task *Task) Run() *Task { task.Status = TaskStatusFinished } } - task.err = taskErr - task.result = taskResult + task.Err = taskErr + task.Result = taskResult task.TerminatedAt = time.Now() task.locker.Unlock() - globals.SugarLogger.Debugf("RunTask %s, result:%v, err:%v", task.Name, taskResult, taskErr) + globals.SugarLogger.Debugf("ParallelTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr) close(task.finishChan) close(task.subFinishChan) close(task.quitChan) if task.resultHandler != nil { - task.resultHandler(task.Name, taskResult, task.err) + task.resultHandler(task.Name, taskResult, task.Err) } - }() - return task + }) } diff --git a/business/jxutils/tasksch/task_test.go b/business/jxutils/tasksch/parallel_task_test.go similarity index 100% rename from business/jxutils/tasksch/task_test.go rename to business/jxutils/tasksch/parallel_task_test.go diff --git a/business/jxutils/tasksch/sequence_task.go b/business/jxutils/tasksch/sequence_task.go index 979a3bebd..9fe04f151 100644 --- a/business/jxutils/tasksch/sequence_task.go +++ b/business/jxutils/tasksch/sequence_task.go @@ -21,39 +21,35 @@ type SeqTask struct { func NewSeqTask(taskName string, userName string, worker SeqWorkFunc, stepCount int, params ...interface{}) *SeqTask { task := &SeqTask{ - BaseTask: BaseTask{ - ParallelCount: 1, - params: params, - ID: utils.GetUUID(), - Name: taskName, - CreatedAt: time.Now(), - CreatedBy: userName, - UpdatedAt: time.Now(), - TotalJobCount: stepCount, - TotalItemCount: stepCount, - quitChan: make(chan int, 1), - finishChan: make(chan int, 2), - Status: TaskStatusWorking, - }, worker: worker, } - task.C = task.finishChan + task.Init(1, 1, false, params, taskName, userName, stepCount, stepCount) return task } -func (task *SeqTask) Run() *SeqTask { - go func() { - globals.SugarLogger.Debugf("Run SeqTask %s", task.Name) +func (task *SeqTask) Run() { + task.run(func() { + globals.SugarLogger.Debugf("SeqTask.Run %s", task.Name) var taskErr error - var taskResult interface{} + var taskResult []interface{} for i := 0; i < task.TotalItemCount; i++ { - taskResult, taskErr = task.worker(i, task.params...) - task.finishedOneJob(1, taskErr) - if taskErr != nil { - break + select { + case <-task.quitChan: + goto EndFor + default: + } + result, err := task.worker(i, task.params...) + task.finishedOneJob(1, err) + if taskErr = err; taskErr != nil { + globals.SugarLogger.Infof("SeqTask.Run %s step:%d failed with error:%v", task.Name, i, err) + if !task.IsContinueWhenError { + break + } + } else if result != nil { + taskResult = append(taskResult, utils.Interface2Slice(result)...) } } - + EndFor: task.locker.Lock() if taskErr != nil { // 如果有错误,肯定就是失败了 task.Status = TaskStatusFailed @@ -65,15 +61,14 @@ func (task *SeqTask) Run() *SeqTask { task.Status = TaskStatusFinished } } - task.err = taskErr - task.result = taskResult + task.Err = taskErr + task.Result = taskResult task.TerminatedAt = time.Now() task.locker.Unlock() - globals.SugarLogger.Debugf("Run SeqTask %s, result:%v, err:%v", task.Name, taskResult, taskErr) + globals.SugarLogger.Debugf("SeqTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr) close(task.finishChan) close(task.quitChan) - }() - return task + }) } diff --git a/business/jxutils/tasksch/sequence_task_test.go b/business/jxutils/tasksch/sequence_task_test.go new file mode 100644 index 000000000..fc3ea8d7b --- /dev/null +++ b/business/jxutils/tasksch/sequence_task_test.go @@ -0,0 +1,52 @@ +package tasksch + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "git.rosy.net.cn/baseapi/utils" +) + +func TestRunSeqTask(t *testing.T) { + var seqTask ITask + seqTask = NewSeqTask("TestSeqTask", "autotest", func(step int, params ...interface{}) (result interface{}, err error) { + switch step { + case 0: + fmt.Println("ONE") + task2 := NewParallelTask("hello", "xjh", nil, func(batchItemList []interface{}, params ...interface{}) (retVal interface{}, err error) { + i := batchItemList[0].(int) + time.Sleep(2 * time.Second) + fmt.Println(i * 2) + return nil, nil + }, []int{1, 2, 3}) + seqTask.AddChild(task2) + time.Sleep(time.Duration(rand.Intn(3)) * time.Second) + task2.Run() + case 1: + fmt.Println("TWO") + time.Sleep(time.Duration(rand.Intn(3)) * time.Second) + case 2: + fmt.Println("THREE") + time.Sleep(time.Duration(rand.Intn(3)) * time.Second) + case 3: + fmt.Println("FOUR") + time.Sleep(time.Duration(rand.Intn(3)) * time.Second) + case 4: + fmt.Println("FIVE") + time.Sleep(time.Duration(rand.Intn(3)) * time.Second) + } + return []string{"1"}, nil + }, 5) + + seqTask.Run() + time.Sleep(3 * time.Second) + seqTask.Cancel() + fmt.Println(utils.Format4Output(seqTask, false)) + result, err := seqTask.GetResult(0) + if err != nil { + t.Fatal(err) + } + t.Log(result) +} diff --git a/business/jxutils/tasksch/task.go b/business/jxutils/tasksch/task.go index 95cde054e..05bb49412 100644 --- a/business/jxutils/tasksch/task.go +++ b/business/jxutils/tasksch/task.go @@ -3,6 +3,9 @@ package tasksch import ( "sync" "time" + + "git.rosy.net.cn/baseapi/utils" + "git.rosy.net.cn/jx-callback/globals" ) const ( @@ -17,8 +20,10 @@ const ( TaskStatusEnd = 4 ) +type TaskList []ITask + type ITask interface { - Run() *ITask + Run() GetResult(duration time.Duration) (retVal []interface{}, err error) Cancel() GetTotalItemCount() int @@ -26,6 +31,9 @@ type ITask interface { GetTotalJobCount() int GetFinishedJobCount() int GetStatus() int + GetCreatedAt() time.Time + + AddChild(task ITask) } type BaseTask struct { @@ -46,24 +54,24 @@ type BaseTask struct { FailedJobCount int `json:"failedJobCount"` Status int `json:"status"` + Result []interface{} `json:"result"` + Err error `json:"err"` + Children TaskList `json:"children"` + finishChan chan int C <-chan int `json:"-"` params []interface{} quitChan chan int locker sync.RWMutex - result interface{} - err error } -type TaskList []*Task - func (s TaskList) Len() int { return len(s) } func (s TaskList) Less(i, j int) bool { - return s[i].CreatedAt.Sub(s[j].CreatedAt) < 0 + return s[i].GetCreatedAt().Sub(s[j].GetCreatedAt()) < 0 } func (s TaskList) Swap(i, j int) { @@ -72,10 +80,29 @@ func (s TaskList) Swap(i, j int) { s[j] = tmp } +func (t *BaseTask) Init(parallelCount, batchSize int, isContinueWhenError bool, params []interface{}, name, userName string, totalItemCount, totalJobCount int) { + t.ID = utils.GetUUID() + t.ParallelCount = parallelCount + t.BatchSize = batchSize + t.IsContinueWhenError = isContinueWhenError + t.params = params + t.Name = name + t.CreatedAt = time.Now() + t.CreatedBy = userName + t.UpdatedAt = t.CreatedAt + t.TerminatedAt = utils.DefaultTimeValue + t.TotalItemCount = totalItemCount + t.TotalJobCount = totalJobCount + t.quitChan = make(chan int, parallelCount) + t.finishChan = make(chan int, 2) + t.Status = TaskStatusWorking + + t.C = t.finishChan +} + func (t *BaseTask) GetResult(duration time.Duration) (retVal []interface{}, err error) { if t.GetStatus() >= TaskStatusEndBegin { - retVal, _ = t.result.([]interface{}) - return retVal, t.err + return t.Result, t.Err } if duration == 0 { duration = time.Hour * 10000 // duration为0表示无限等待 @@ -87,22 +114,32 @@ func (t *BaseTask) GetResult(duration time.Duration) (retVal []interface{}, err t.locker.RLock() defer t.locker.RUnlock() - retVal, _ = t.result.([]interface{}) - return retVal, t.err + return t.Result, t.Err case <-timer.C: } return nil, ErrTaskNotFinished } +func (t *BaseTask) GetCreatedAt() time.Time { + t.locker.RLock() + defer t.locker.RUnlock() + + return t.CreatedAt +} + func (t *BaseTask) Cancel() { t.locker.Lock() - defer t.locker.Unlock() if t.Status < TaskStatusEndBegin && t.Status != TaskStatusCanceling { t.Status = TaskStatusCanceling for i := 0; i < t.ParallelCount; i++ { t.quitChan <- 0 } } + t.locker.Unlock() + + for _, subTask := range t.Children { + subTask.Cancel() + } } func (t *BaseTask) GetTotalItemCount() int { @@ -134,6 +171,24 @@ func (t *BaseTask) GetStatus() int { return t.Status } +func (t *BaseTask) AddChild(task ITask) { + t.locker.Lock() + defer t.locker.Unlock() + + t.Children = append(t.Children, task) +} + +func (t *BaseTask) run(taskHandler func()) { + go func() { + taskHandler() + for _, subTask := range t.Children { + if _, err := subTask.GetResult(0); err != nil { + globals.SugarLogger.Warnf("BaseTask run, failed with error:%v", err) + } + } + }() +} + ///////// func (t *BaseTask) finishedOneJob(itemCount int, err error) { t.locker.Lock() diff --git a/business/jxutils/tasksch/task_man.go b/business/jxutils/tasksch/task_man.go index d3ead321f..22365f893 100644 --- a/business/jxutils/tasksch/task_man.go +++ b/business/jxutils/tasksch/task_man.go @@ -11,11 +11,11 @@ var ( ) type TaskMan struct { - taskList map[string]*Task + taskList map[string]ITask } func init() { - defTaskMan.taskList = make(map[string]*Task) + defTaskMan.taskList = make(map[string]ITask) } func (m *TaskMan) RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHandlerFunc, parallelCount, batchSize int, userName string, worker WorkFunc, itemList interface{}, params ...interface{}) *Task { @@ -24,13 +24,14 @@ func (m *TaskMan) RunTask(taskName string, isContinueWhenError bool, resultHandl return task } -func (m *TaskMan) GetTasks(taskID string, fromStatus, toStatus int, lastHours int) (taskList []*Task) { +func (m *TaskMan) GetTasks(taskID string, fromStatus, toStatus int, lastHours int) (taskList TaskList) { if lastHours == 0 { lastHours = defLastHours } lastTime := time.Now().Add(time.Duration(-lastHours) * time.Hour).Unix() for k, v := range m.taskList { - if !((taskID != "" && taskID != k) || v.Status < fromStatus || v.Status > toStatus || v.CreatedAt.Unix() < lastTime) { + status := v.GetStatus() + if !((taskID != "" && taskID != k) || status < fromStatus || status > toStatus || v.GetCreatedAt().Unix() < lastTime) { taskList = append(taskList, v) } } @@ -42,6 +43,6 @@ func RunManagedTask(taskName string, isContinueWhenError bool, resultHandler Res return defTaskMan.RunTask(taskName, isContinueWhenError, resultHandler, parallelCount, batchSize, userName, worker, itemList, params...) } -func GetTasks(taskID string, fromStatus, toStatus int, lastHours int) (taskList []*Task) { +func GetTasks(taskID string, fromStatus, toStatus int, lastHours int) (taskList TaskList) { return defTaskMan.GetTasks(taskID, fromStatus, toStatus, lastHours) }