Files
jx-callback/business/jxutils/tasksch/parallel_task.go

211 lines
5.9 KiB
Go

package tasksch
import (
"errors"
"time"
"git.rosy.net.cn/baseapi/utils"
"git.rosy.net.cn/jx-callback/business/jxutils"
"git.rosy.net.cn/jx-callback/business/jxutils/jxcontext"
"git.rosy.net.cn/jx-callback/globals"
)
// Worker-pool sizing. DefParallelCount is the worker count used by
// NewParallelConfig and substituted by NewParallelTask when a config's
// ParallelCount is zero; MaxParallelCount is the upper bound NewParallelTask
// clamps ParallelCount to.
const (
	DefParallelCount = 10
	MaxParallelCount = 50
)

type (
	// WorkFunc processes one batch of items. Run calls it with the owning
	// task, the batch, and the params given to NewParallelTask. A non-nil
	// retVal is flattened with utils.Interface2Slice and appended to the task
	// result; a non-nil err marks the batch as failed.
	WorkFunc func(task *ParallelTask, batchItemList []interface{}, params ...interface{}) (retVal interface{}, err error)

	// ResultHandlerFunc is called once at the end of Run with the task name,
	// the collected results (nil when the task stopped on an error) and the
	// task's final error.
	ResultHandlerFunc func(taskName string, result []interface{}, err error)
)
// ParallelConfig holds the tuning options for a ParallelTask. Build one with
// NewParallelConfig and the chained Set* methods, or pass nil to
// NewParallelTask to get the defaults.
type ParallelConfig struct {
	// Disabled options, kept commented out:
	// ParentTask ITask
	// IsAsync bool

	// ParallelCount is the maximum number of concurrent workers.
	// NewParallelTask substitutes DefParallelCount for 0 and caps the value at
	// MaxParallelCount and at the number of batches.
	ParallelCount int
	// BatchSize is the number of items handed to each WorkFunc call; the item
	// list is split with jxutils.SplitSlice.
	BatchSize int
	// IsContinueWhenError selects failure handling. When false, the first
	// worker error fails and cancels the task. When true, error messages are
	// collected and the remaining batches are still processed.
	IsContinueWhenError bool
	// ResultHandler, if non-nil, is called once when Run finishes.
	ResultHandler ResultHandlerFunc
}
// ParallelTask processes a list of items in batches using a bounded number of
// concurrent workers. Create it with NewParallelTask and start it with Run.
type ParallelTask struct {
	BaseTask
	// resultHandler is the optional callback invoked at the end of Run.
	resultHandler ResultHandlerFunc
	// worker is called once per batch.
	worker WorkFunc
	// jobList holds the batches produced by splitting the input items.
	jobList [][]interface{}
	// taskChan is the queue of batches consumed by the workers. Run fills it
	// and then closes it.
	taskChan chan []interface{}
	// subFinishChan receives at most one report per worker: its collected
	// results ([]interface{}), the error that stopped it, or nil.
	subFinishChan chan interface{}
}
// ErrTaskNotFinished ("任务还未完成": the task has not finished yet) is not
// referenced in this part of the file. NOTE(review): presumably returned when a
// result is requested before the task ends — confirm.
var ErrTaskNotFinished = errors.New("任务还未完成")

// ErrTaskIsCanceled ("任务被取消了": the task was canceled) is the cause Run
// records when a task ends with batches still left in its queue.
var ErrTaskIsCanceled = errors.New("任务被取消了")
// NewParallelConfig returns a config populated with the package defaults:
// DefParallelCount workers, one item per batch, stop on the first error, and
// no result handler.
func NewParallelConfig() *ParallelConfig {
	// The zero value already gives IsContinueWhenError == false and
	// ResultHandler == nil; only the non-zero defaults need setting.
	cfg := new(ParallelConfig)
	cfg.ParallelCount = DefParallelCount
	cfg.BatchSize = 1
	return cfg
}
// SetParallelCount stores the requested number of concurrent workers and
// returns c so calls can be chained. The value is stored as given; any
// defaulting or clamping happens later in NewParallelTask.
func (c *ParallelConfig) SetParallelCount(parallelCount int) *ParallelConfig {
	c.ParallelCount = parallelCount
	return c
}
// SetBatchSize sets how many items each WorkFunc call receives and returns c
// for chaining.
func (c *ParallelConfig) SetBatchSize(batchSize int) *ParallelConfig {
	c.BatchSize = batchSize
	return c
}
// SetIsContinueWhenError chooses whether the remaining batches keep running
// after a worker error (true) or the task stops on the first error (false).
// It returns c for chaining.
func (c *ParallelConfig) SetIsContinueWhenError(isContinueWhenError bool) *ParallelConfig {
	c.IsContinueWhenError = isContinueWhenError
	return c
}
// func (c *ParallelConfig) SetIsAsync(isAsync bool) *ParallelConfig {
// c.IsAsync = isAsync
// return c
// }
// SetResultHandler registers the callback invoked once when the task's Run
// finishes and returns c for chaining. A nil handler disables the callback.
func (c *ParallelConfig) SetResultHandler(resultHandler ResultHandlerFunc) *ParallelConfig {
	c.ResultHandler = resultHandler
	return c
}
// NewParallelTask builds a ParallelTask that splits itemList into batches of
// config.BatchSize items. Up to config.ParallelCount concurrent invocations of
// worker process those batches, and params are forwarded unchanged to every
// worker call. The task is only constructed here; call Run to start it.
//
// A nil config means NewParallelConfig(). The config is copied, so the
// normalization below never leaks back into a config the caller may reuse for
// other tasks. That normalization is:
//   - a non-positive ParallelCount becomes DefParallelCount;
//   - ParallelCount is clamped to MaxParallelCount and to the number of batches;
//   - a non-positive BatchSize becomes 1.
func NewParallelTask(taskName string, config *ParallelConfig, ctx *jxcontext.Context, worker WorkFunc, itemList interface{}, params ...interface{}) *ParallelTask {
	if config == nil {
		config = NewParallelConfig()
	}
	// Work on a copy: capping ParallelCount to this task's job count must not
	// shrink the worker pool of later tasks built from the same config.
	cfg := *config
	if cfg.ParallelCount <= 0 {
		// A negative count would make the make(chan ..., n) below panic.
		cfg.ParallelCount = DefParallelCount
	}
	if cfg.ParallelCount > MaxParallelCount {
		cfg.ParallelCount = MaxParallelCount
	}
	if cfg.BatchSize <= 0 {
		// Same default as NewParallelConfig. NOTE(review): SplitSlice's
		// behavior for a non-positive batch size is not visible here.
		cfg.BatchSize = 1
	}
	realItemList := utils.Interface2Slice(itemList)
	jobList := jxutils.SplitSlice(realItemList, cfg.BatchSize)
	jobListLen := len(jobList)
	// Never start more workers than there are jobs (0 for an empty item list).
	if cfg.ParallelCount > jobListLen {
		cfg.ParallelCount = jobListLen
	}
	// subFinishChan gets at most one report per worker, and taskChan can hold
	// every job, so neither the workers' final send nor Run's enqueue loop blocks.
	task := &ParallelTask{
		subFinishChan: make(chan interface{}, cfg.ParallelCount),
		taskChan:      make(chan []interface{}, jobListLen),
		resultHandler: cfg.ResultHandler,
		worker:        worker,
		jobList:       jobList,
	}
	task.Init(cfg.ParallelCount, cfg.BatchSize, cfg.IsContinueWhenError, params, taskName, ctx, len(realItemList), jobListLen)
	return task
}
// Run executes the task in four steps:
//  1. start task.ParallelCount workers;
//  2. feed them every batch in task.jobList;
//  3. wait for the workers to report;
//  4. record the final Status, Err, Result and TerminatedAt under task.locker,
//     then call the optional resultHandler.
//
// The whole body is passed to BaseTask.run. NOTE(review): whether run executes
// it synchronously or in the background is defined outside this file — confirm.
//
// Worker protocol: each worker repeatedly takes one batch from taskChan and
// calls task.worker on it through task.callWorker. A worker stops when:
//   - task.quitChan becomes readable (cancellation) — it reports nil;
//   - taskChan is closed and drained — it reports its accumulated results;
//   - a batch fails and IsContinueWhenError is false — it reports the error.
//
// With IsContinueWhenError true, failures are logged and appended to
// task.detailErrMsgList, and the worker moves on to the next batch. Each worker
// sends at most one value to subFinishChan, and only while
// task.Status < TaskStatusEndBegin.
//
// Results from different workers are concatenated in the order the workers
// report. task.Result is therefore not guaranteed to follow the order of the
// input items.
//
// NOTE(review): a worker that exits because of quitChan reports nil and drops
// the results it had already collected. If taskChan is already empty at that
// point, the task is still marked TaskStatusFinished with an incomplete
// Result — confirm whether that is intended.
//
// NOTE(review): the collection loop waits for task.ParallelCount reports. If
// Status can reach TaskStatusEndBegin elsewhere (e.g. inside Cancel, which is
// not visible here) before every worker has reported, that loop would block
// forever — verify against BaseTask.
func (task *ParallelTask) Run() {
	task.run(func() {
		globals.SugarLogger.Debugf("ParallelTask.Run %s", task.Name)
		for i := 0; i < task.ParallelCount; i++ {
			// NOTE(review): presumably CallFuncAsync runs the closure in its own goroutine.
			utils.CallFuncAsync(func() {
				var chanRetVal interface{} // the single value this worker reports
				retVal := make([]interface{}, 0)
				for {
					select {
					case <-task.quitChan: // canceled
						goto end
					default:
						// Not canceled: wait for the next batch or for taskChan to close.
						select {
						case job, ok := <-task.taskChan:
							if !ok { // taskChan closed and drained: all batches have been taken
								chanRetVal = retVal
								goto end
							} else {
								result, err := task.callWorker(func() (retVal interface{}, err error) {
									return task.worker(task, job, task.params...)
								})
								// globals.SugarLogger.Debugf("ParallelTask.Run %s, after call worker result:%v, err:%v", task.Name, result, err)
								task.finishedOneJob(len(job), err)
								if err == nil {
									if result != nil {
										retVal = append(retVal, utils.Interface2Slice(result)...)
									}
								} else {
									globals.SugarLogger.Infof("ParallelTask.Run %s, subtask(job:%s, params:%s) result:%v, failed with error:%v", task.Name, utils.Format4Output(job, true), utils.Format4Output(task.params, true), result, err)
									if !task.IsContinueWhenError { // error: stop this worker and report the error
										chanRetVal = err
										goto end
									}
									// detailErrMsgList is shared by all workers, so guard it.
									task.locker.Lock()
									task.detailErrMsgList = append(task.detailErrMsgList, err.Error())
									task.locker.Unlock()
								}
							}
						}
					}
				}
			end:
				// globals.SugarLogger.Debugf("ParallelTask.Run %s, put to chann chanRetVal:%v", task.Name, chanRetVal)
				// Report only while the task is not yet terminal. Run sets the
				// terminal Status under the write lock before closing subFinishChan,
				// so checking under RLock avoids sending on a closed channel.
				// (Assumes TaskStatusFailed/Canceled/Finished >= TaskStatusEndBegin — TODO confirm.)
				task.locker.RLock()
				if task.Status < TaskStatusEndBegin {
					task.subFinishChan <- chanRetVal
				}
				task.locker.RUnlock()
			})
		}
		// Presumably never blocks: NewParallelTask sizes taskChan to hold at least every job.
		for _, job := range task.jobList {
			task.taskChan <- job
		}
		// Closing tells the workers that no more batches will arrive.
		close(task.taskChan)
		taskResult := make([]interface{}, 0)
		var taskErr error
		// Collect one report per worker.
		for i := 0; i < task.ParallelCount; i++ {
			result := <-task.subFinishChan
			// globals.SugarLogger.Debugf("ParallelTask.Run %s, received from chann result:%v", taskName, result)
			if err, ok := result.(error); ok {
				task.Cancel()
				taskResult = nil
				taskErr = err
				break // stop collecting on the first error (original TODO: should we really break out immediately on error?)
			} else if result != nil {
				resultList := result.([]interface{})
				taskResult = append(taskResult, resultList...)
			}
		}
		task.locker.Lock()
		if taskErr != nil { // any error means the task failed
			task.Status = TaskStatusFailed
		} else {
			// Batches left in the queue mean workers quit early: treat as canceled.
			if len(task.taskChan) > 0 {
				taskErr = ErrTaskIsCanceled
				task.Status = TaskStatusCanceled
			} else {
				task.Status = TaskStatusFinished
			}
		}
		if taskErr != nil {
			task.Err = NewTaskError(task.Name, taskErr)
		} else {
			// No fatal error: build Err from messages collected with IsContinueWhenError.
			task.Err = task.buildTaskErrFromDetail()
		}
		task.Result = taskResult
		task.TerminatedAt = time.Now()
		task.locker.Unlock()
		globals.SugarLogger.Debugf("ParallelTask.Run %s, err:%v", task.Name, task.Err)
		close(task.subFinishChan)
		if task.resultHandler != nil {
			task.resultHandler(task.Name, taskResult, task.Err)
		}
	})
}
// AddChild first makes this task the parent of child. It then registers child
// through the embedded BaseTask.AddChild and returns that call's result.
func (task *ParallelTask) AddChild(child ITask) ITask {
	child.SetParent(task)
	return task.BaseTask.AddChild(child)
}