package tasksch

import (
	"errors"
	"time"

	"git.rosy.net.cn/baseapi/utils"
	"git.rosy.net.cn/jx-callback/business/jxutils"
	"git.rosy.net.cn/jx-callback/globals"
)

const (
	// DefParallelCount is the worker count used when a config's
	// ParallelCount is zero or negative.
	DefParallelCount = 10
	// MaxParallelCount is the upper bound on concurrent workers per task.
	MaxParallelCount = 10
)

// WorkFunc processes one batch of items. params are the extra arguments
// supplied when the task was created. A non-nil retVal is flattened with
// utils.Interface2Slice and appended to the task's merged result.
type WorkFunc func(batchItemList []interface{}, params ...interface{}) (retVal interface{}, err error)

// ResultHandlerFunc is invoked once after a task terminates with the task
// name, the merged results of all workers (nil on failure), and the task's
// final error.
type ResultHandlerFunc func(taskName string, result []interface{}, err error)

// ParallelConfig controls how a ParallelTask splits and schedules its items.
type ParallelConfig struct {
	// ParallelCount is the number of worker goroutines. Values <= 0 mean
	// DefParallelCount; the effective value is capped at MaxParallelCount
	// and at the number of batches.
	ParallelCount int
	// BatchSize is the number of items handed to each WorkFunc call.
	BatchSize int
	// IsContinueWhenError, when false, makes the first WorkFunc error stop
	// its worker and fail the whole task.
	IsContinueWhenError bool
	// ResultHandler is optional and is called after the task terminates.
	ResultHandler ResultHandlerFunc
}

// ParallelTask runs a WorkFunc over batches of items using a fixed number of
// worker goroutines. Status, locking and cancellation state live in the
// embedded BaseTask (defined elsewhere in this package).
type ParallelTask struct {
	BaseTask

	resultHandler ResultHandlerFunc
	worker        WorkFunc
	// jobList holds the batches produced from the input item list.
	jobList [][]interface{}
	// taskChan queues every batch followed by one nil end marker per worker.
	taskChan chan []interface{}
	// subFinishChan receives at most one report from each worker: its
	// collected results, the error that stopped it, or nil if it quit early.
	subFinishChan chan interface{}
}

var (
	// ErrTaskNotFinished reports that a task has not finished yet.
	ErrTaskNotFinished = errors.New("任务还未完成")
	// ErrTaskIsCanceled is recorded as a task's error when it terminates
	// before all queued batches were consumed.
	ErrTaskIsCanceled = errors.New("任务被取消了")
)

// NewParallelConfig returns a config with DefParallelCount workers, a batch
// size of 1, stop-on-first-error semantics and no result handler.
func NewParallelConfig() *ParallelConfig {
	return &ParallelConfig{
		ParallelCount:       DefParallelCount,
		BatchSize:           1,
		IsContinueWhenError: false,
		ResultHandler:       nil,
	}
}

// SetParallelCount sets ParallelCount and returns c for chaining.
func (c *ParallelConfig) SetParallelCount(parallelCount int) *ParallelConfig {
	c.ParallelCount = parallelCount
	return c
}

// SetBatchSize sets BatchSize and returns c for chaining.
func (c *ParallelConfig) SetBatchSize(batchSize int) *ParallelConfig {
	c.BatchSize = batchSize
	return c
}

// SetIsContinueWhenError sets IsContinueWhenError and returns c for chaining.
func (c *ParallelConfig) SetIsContinueWhenError(isContinueWhenError bool) *ParallelConfig {
	c.IsContinueWhenError = isContinueWhenError
	return c
}

// SetResultHandler sets ResultHandler and returns c for chaining.
func (c *ParallelConfig) SetResultHandler(resultHandler ResultHandlerFunc) *ParallelConfig {
	c.ResultHandler = resultHandler
	return c
}

// NewParallelTask builds (but does not start) a ParallelTask that applies
// worker to itemList in batches of config.BatchSize.
//
// A nil config means NewParallelConfig() defaults. The caller's config is
// NOT modified: normalization (defaulting, capping at MaxParallelCount and at
// the batch count) is applied to a private copy, so one config value can be
// safely reused for several tasks.
//
// NOTE(review): BatchSize is passed to jxutils.SplitSlice unchecked; behavior
// for BatchSize <= 0 depends on that helper — confirm.
func NewParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *ParallelTask {
	var cfg ParallelConfig
	if config != nil {
		cfg = *config
	} else {
		cfg = *NewParallelConfig()
	}
	// A non-positive count would otherwise reach make(chan, n) and panic.
	if cfg.ParallelCount <= 0 {
		cfg.ParallelCount = DefParallelCount
	}
	if cfg.ParallelCount > MaxParallelCount {
		cfg.ParallelCount = MaxParallelCount
	}

	realItemList := utils.Interface2Slice(itemList)
	jobList := jxutils.SplitSlice(realItemList, cfg.BatchSize)
	jobListLen := len(jobList)
	// Never start more workers than there are batches to process.
	if cfg.ParallelCount > jobListLen {
		cfg.ParallelCount = jobListLen
	}

	task := &ParallelTask{
		subFinishChan: make(chan interface{}, cfg.ParallelCount),
		// Exactly jobListLen batches plus ParallelCount end markers are sent
		// by Run, so this capacity guarantees those sends never block.
		taskChan:      make(chan []interface{}, jobListLen+cfg.ParallelCount),
		resultHandler: cfg.ResultHandler,
		worker:        worker,
		jobList:       jobList,
	}
	task.Init(cfg.ParallelCount, cfg.BatchSize, cfg.IsContinueWhenError, params, taskName, userName, len(realItemList), jobListLen)
	return task
}

// RunParallelTask creates a task with NewParallelTask, calls Run on it and
// returns the task.
func RunParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *ParallelTask {
	task := NewParallelTask(taskName, userName, config, worker, itemList, params...)
	task.Run()
	return task
}

// RunTask is the positional-argument form of RunParallelTask: it builds a
// ParallelConfig from the given values and runs the task.
func RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHandlerFunc, parallelCount, batchSize int, userName string, worker WorkFunc, itemList interface{}, params ...interface{}) *ParallelTask {
	config := NewParallelConfig().
		SetBatchSize(batchSize).
		SetIsContinueWhenError(isContinueWhenError).
		SetParallelCount(parallelCount).
		SetResultHandler(resultHandler)
	return RunParallelTask(taskName, userName, config, worker, itemList, params...)
}

// Run starts the workers, feeds them every batch, gathers their results and
// records the final Status/Err/Result/TerminatedAt on the task before calling
// the optional result handler.
//
// The work is wrapped in task.run, which belongs to BaseTask (not visible
// here); presumably it executes the closure asynchronously — confirm.
func (task *ParallelTask) Run() {
	task.run(func() {
		globals.SugarLogger.Debugf("ParallelTask.Run %s", task.Name)

		// consume processes batches until an end marker, a quit signal, or
		// (when not continuing on errors) the first error. It returns the
		// worker's report for subFinishChan: the collected results, the
		// stopping error, or nil when the worker quit.
		consume := func() interface{} {
			retVal := make([]interface{}, 0)
			for {
				select {
				case <-task.quitChan:
					return nil
				case job := <-task.taskChan:
					if job == nil {
						// End marker: this worker has no more batches.
						return retVal
					}
					result, err := task.worker(job, task.params...)
					task.finishedOneJob(len(job), err)
					if err != nil {
						globals.SugarLogger.Infof("ParallelTask.Run %s, subtask(job:%s, params:%s) result:%v, failed with error:%v", task.Name, utils.Format4Output(job, true), utils.Format4Output(task.params, true), result, err)
						if !task.IsContinueWhenError {
							return err
						}
						continue
					}
					if result != nil {
						retVal = append(retVal, utils.Interface2Slice(result)...)
					}
				}
			}
		}

		runWorker := func() {
			report := consume()
			// Only report while the task is still running. The status check
			// is done under the read lock because the collector below sets the
			// terminal status under the write lock and only then closes
			// subFinishChan. NOTE(review): this relies on the terminal
			// statuses being >= TaskStatusEndBegin — confirm.
			task.locker.RLock()
			if task.Status < TaskStatusEndBegin {
				task.subFinishChan <- report
			}
			task.locker.RUnlock()
		}

		for i := 0; i < task.ParallelCount; i++ {
			utils.CallFuncAsync(runWorker)
		}

		// Queue every batch, then one nil end marker per worker. taskChan is
		// sized in NewParallelTask to hold all of them.
		for _, job := range task.jobList {
			task.taskChan <- job
		}
		for i := 0; i < task.ParallelCount; i++ {
			task.taskChan <- nil
		}

		// Collect one report per worker. On the first error, cancel the
		// remaining workers and stop waiting for them.
		taskResult := make([]interface{}, 0)
		var taskErr error
		for i := 0; i < task.ParallelCount; i++ {
			report := <-task.subFinishChan
			if err, ok := report.(error); ok {
				task.Cancel()
				taskResult = nil
				taskErr = err
				break
			} else if report != nil {
				taskResult = append(taskResult, report.([]interface{})...)
			}
		}

		task.locker.Lock()
		switch {
		case taskErr != nil:
			// Any worker error means the task failed.
			task.Status = TaskStatusFailed
		case len(task.taskChan) > 0:
			// Workers left batches or end markers unconsumed, which happens
			// when they exit via quitChan: treat the task as canceled.
			taskErr = ErrTaskIsCanceled
			task.Status = TaskStatusCanceled
		default:
			task.Status = TaskStatusFinished
		}
		task.Err = taskErr
		task.Result = taskResult
		task.TerminatedAt = time.Now()
		task.locker.Unlock()

		globals.SugarLogger.Debugf("ParallelTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr)

		close(task.finishChan)
		close(task.subFinishChan)
		// NOTE(review): if BaseTask.Cancel also closes quitChan, this close
		// would panic on the error path above — confirm.
		close(task.quitChan)

		if task.resultHandler != nil {
			task.resultHandler(task.Name, taskResult, task.Err)
		}
	})
}