Files
jx-callback/business/jxutils/tasksch/parallel_task.go
2018-10-20 12:02:52 +08:00

207 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tasksch
import (
"errors"
"time"
"git.rosy.net.cn/baseapi/utils"
"git.rosy.net.cn/jx-callback/business/jxutils"
"git.rosy.net.cn/jx-callback/globals"
)
// Worker-count limits for ParallelTask.
const (
	// DefParallelCount is the default worker count: NewParallelConfig starts
	// with it, and NewParallelTask substitutes it when ParallelCount is zero.
	DefParallelCount = 10
	// MaxParallelCount is the upper bound NewParallelTask clamps ParallelCount to.
	MaxParallelCount = 10
)
// WorkFunc processes one batch of items. params are the extra arguments that
// were passed to NewParallelTask / RunParallelTask / RunTask. On success a
// non-nil retVal is flattened with utils.Interface2Slice and appended to the
// task's aggregated result. A non-nil err is logged and, unless
// IsContinueWhenError is set, fails the whole task.
type WorkFunc func(batchItemList []interface{}, params ...interface{}) (retVal interface{}, err error)

// ResultHandlerFunc is called once when a ParallelTask's Run completes. It
// receives the task name, the aggregated result (nil if a worker error failed
// the task) and the final error.
type ResultHandlerFunc func(taskName string, result []interface{}, err error)
// ParallelConfig holds the tunables for a ParallelTask. Create one with
// NewParallelConfig and adjust it with the chainable Set* methods.
type ParallelConfig struct {
	// ParallelCount is the number of concurrent workers. NewParallelTask uses
	// DefParallelCount when it is zero and caps it at MaxParallelCount and at
	// the number of batches.
	ParallelCount int
	// BatchSize is how many items are grouped into each batch handed to the
	// WorkFunc (the split is done by jxutils.SplitSlice).
	BatchSize int
	// IsContinueWhenError keeps the task running after a WorkFunc error.
	// When false, the first error fails the task.
	IsContinueWhenError bool
	// ResultHandler, if non-nil, is called once when the task finishes.
	ResultHandler ResultHandlerFunc
}
// ParallelTask runs a WorkFunc over batches of items using up to ParallelCount
// concurrent workers. Build one with NewParallelTask (or RunParallelTask /
// RunTask) and start it with Run.
type ParallelTask struct {
	// BaseTask (defined elsewhere in the package) supplies the shared state Run
	// uses: Name, Status, locker, quitChan, finishChan, Err, Result, params, ...
	BaseTask
	resultHandler ResultHandlerFunc  // optional callback invoked at the end of Run
	worker        WorkFunc           // called once per batch
	jobList       [][]interface{}    // input items split into batches by jxutils.SplitSlice
	taskChan      chan []interface{} // job queue: every batch, then one nil end marker per worker
	subFinishChan chan interface{}   // per-worker report: []interface{} results, an error, or nil
}
// Sentinel errors for task state.
var (
	// ErrTaskNotFinished ("task not finished yet") is not referenced in the code
	// shown here; presumably returned by result accessors elsewhere in the package.
	ErrTaskNotFinished = errors.New("任务还未完成")
	// ErrTaskIsCanceled ("task was canceled") is recorded as the task error when
	// Run ends with jobs still left in the queue.
	ErrTaskIsCanceled = errors.New("任务被取消了")
)
// NewParallelConfig returns a config populated with the package defaults:
// DefParallelCount workers, one item per batch, stop on the first error, and
// no result handler.
func NewParallelConfig() *ParallelConfig {
	cfg := new(ParallelConfig)
	cfg.ParallelCount = DefParallelCount
	cfg.BatchSize = 1
	// IsContinueWhenError and ResultHandler intentionally keep their zero
	// values (false and nil).
	return cfg
}
// SetParallelCount sets ParallelCount and returns c so calls can be chained.
func (c *ParallelConfig) SetParallelCount(parallelCount int) *ParallelConfig {
	c.ParallelCount = parallelCount
	return c
}
// SetBatchSize sets BatchSize and returns c so calls can be chained.
func (c *ParallelConfig) SetBatchSize(batchSize int) *ParallelConfig {
	c.BatchSize = batchSize
	return c
}
// SetIsContinueWhenError sets IsContinueWhenError and returns c so calls can be
// chained.
func (c *ParallelConfig) SetIsContinueWhenError(isContinueWhenError bool) *ParallelConfig {
	c.IsContinueWhenError = isContinueWhenError
	return c
}
// SetResultHandler sets ResultHandler and returns c so calls can be chained.
func (c *ParallelConfig) SetResultHandler(resultHandler ResultHandlerFunc) *ParallelConfig {
	c.ResultHandler = resultHandler
	return c
}
// NewParallelTask builds, but does not start, a task that splits itemList into
// batches of config.BatchSize items and hands each batch to worker. At most
// config.ParallelCount batches run concurrently, and params are forwarded to
// every worker call. A nil config means NewParallelConfig().
//
// The effective worker count is normalized:
//   - a non-positive ParallelCount falls back to DefParallelCount;
//   - it is capped at MaxParallelCount;
//   - it never exceeds the number of batches, so an empty itemList yields
//     zero workers.
//
// The caller's config is not modified; normalization happens on a private copy.
func NewParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *ParallelTask {
	// Work on a copy so a config shared by several tasks is not permanently
	// clamped to the first task's (possibly small) job count.
	var cfg ParallelConfig
	if config != nil {
		cfg = *config
	} else {
		cfg = *NewParallelConfig()
	}
	// A negative count would make the channel allocations below panic
	// (make with a negative buffer size), so treat it the same as zero.
	if cfg.ParallelCount <= 0 {
		cfg.ParallelCount = DefParallelCount
	}
	if cfg.ParallelCount > MaxParallelCount {
		cfg.ParallelCount = MaxParallelCount
	}
	realItemList := utils.Interface2Slice(itemList)
	// NOTE(review): BatchSize is passed to jxutils.SplitSlice unchecked; the
	// behavior for BatchSize <= 0 depends on SplitSlice - confirm callers never
	// pass it.
	jobList := jxutils.SplitSlice(realItemList, cfg.BatchSize)
	jobListLen := len(jobList)
	if cfg.ParallelCount > jobListLen {
		cfg.ParallelCount = jobListLen
	}
	task := &ParallelTask{
		// Each worker reports at most once, so ParallelCount slots keep every
		// report non-blocking even after Run stops reading.
		subFinishChan: make(chan interface{}, cfg.ParallelCount),
		// Run enqueues every job plus one nil end marker per worker. Sizing the
		// buffer for exactly that lets Run enqueue everything without blocking.
		taskChan:      make(chan []interface{}, jobListLen+cfg.ParallelCount),
		resultHandler: cfg.ResultHandler,
		worker:        worker,
		jobList:       jobList,
	}
	task.Init(cfg.ParallelCount, cfg.BatchSize, cfg.IsContinueWhenError, params, taskName, userName, len(realItemList), jobListLen)
	return task
}
// RunParallelTask creates a task with NewParallelTask, calls Run on it, and
// returns it so the caller can track or wait on it.
func RunParallelTask(taskName, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) (task *ParallelTask) {
	task = NewParallelTask(taskName, userName, config, worker, itemList, params...)
	task.Run()
	return task
}
// RunTask is a positional-argument convenience wrapper. It assembles a
// ParallelConfig from the given settings and starts the task through
// RunParallelTask, returning it.
func RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHandlerFunc, parallelCount, batchSize int, userName string, worker WorkFunc, itemList interface{}, params ...interface{}) *ParallelTask {
	config := NewParallelConfig().
		SetParallelCount(parallelCount).
		SetBatchSize(batchSize).
		SetIsContinueWhenError(isContinueWhenError).
		SetResultHandler(resultHandler)
	return RunParallelTask(taskName, userName, config, worker, itemList, params...)
}
// Run starts the task through task.run (BaseTask, not shown), passing a
// callback that:
//  1. starts ParallelCount workers via utils.CallFuncAsync;
//  2. enqueues every batch, then one nil end marker per worker;
//  3. collects one report per worker from subFinishChan;
//  4. records the final Status, Err, Result and TerminatedAt.
//
// Each worker reports a single value on subFinishChan:
//   - []interface{}: the flattened results of the batches it processed, sent
//     after it dequeues a nil end marker;
//   - error: the first worker error, when IsContinueWhenError is false;
//   - nil: it left early because quitChan fired.
//
// Final state:
//   - TaskStatusFailed (Result nil) if any worker reported an error;
//   - otherwise TaskStatusCanceled with ErrTaskIsCanceled if anything was left
//     in taskChan;
//   - otherwise TaskStatusFinished.
//
// Afterwards finishChan, subFinishChan and quitChan are closed, and
// resultHandler (if set) is invoked last.
func (task *ParallelTask) Run() {
	task.run(func() {
		globals.SugarLogger.Debugf("ParallelTask.Run %s", task.Name)
		for i := 0; i < task.ParallelCount; i++ {
			utils.CallFuncAsync(func() {
				var chanRetVal interface{} // the value this worker reports on subFinishChan
				retVal := make([]interface{}, 0)
				for {
					select {
					case <-task.quitChan: // canceled: report nil
						goto end
					case job := <-task.taskChan:
						if job == nil { // end marker: no jobs remain queued, report accumulated results
							chanRetVal = retVal
							goto end
						} else {
							result, err := task.worker(job, task.params...)
							// globals.SugarLogger.Debugf("ParallelTask.Run %s, after call worker result:%v, err:%v", task.Name, result, err)
							task.finishedOneJob(len(job), err)
							if err == nil {
								if result != nil {
									retVal = append(retVal, utils.Interface2Slice(result)...)
								}
							} else {
								globals.SugarLogger.Infof("ParallelTask.Run %s, subtask(job:%s, params:%s) result:%v, failed with error:%v", task.Name, utils.Format4Output(job, true), utils.Format4Output(task.params, true), result, err)
								if !task.IsContinueWhenError { // error and continuing is not allowed: report the error and stop
									chanRetVal = err
									goto end
								}
							}
						}
					}
				}
			end:
				// globals.SugarLogger.Debugf("ParallelTask.Run %s, put to chann chanRetVal:%v", task.Name, chanRetVal)
				// Report only while the task has not reached a terminal status.
				// The collector below sets the final Status under the write lock
				// before closing subFinishChan. Checking Status under the read lock
				// therefore prevents a send on a closed channel. This assumes
				// TaskStatusFailed/Canceled/Finished are all >= TaskStatusEndBegin
				// (defined elsewhere) - confirm.
				task.locker.RLock()
				if task.Status < TaskStatusEndBegin {
					task.subFinishChan <- chanRetVal
				}
				task.locker.RUnlock()
			})
		}
		// Enqueue every batch, then one nil end marker per worker. taskChan is
		// sized in NewParallelTask to hold all of them, so these sends do not block.
		for _, job := range task.jobList {
			task.taskChan <- job
		}
		for i := 0; i < task.ParallelCount; i++ {
			task.taskChan <- nil
		}
		taskResult := make([]interface{}, 0)
		var taskErr error
		// Collect one report per worker.
		// NOTE(review): a worker that skips its report (Status already
		// >= TaskStatusEndBegin) would stall this loop. Confirm that an external
		// BaseTask.Cancel does not move Status past TaskStatusEndBegin before
		// the workers report.
		for i := 0; i < task.ParallelCount; i++ {
			result := <-task.subFinishChan
			// globals.SugarLogger.Debugf("ParallelTask.Run %s, received from chann result:%v", taskName, result)
			if err, ok := result.(error); ok {
				// A worker failed: ask the others to stop (BaseTask.Cancel, not
				// shown) and discard partial results. Later reports, if any, land
				// in the buffered subFinishChan and are never read.
				task.Cancel()
				taskResult = nil
				taskErr = err
				break // open question from the original author: should we break out immediately on error?
			} else if result != nil {
				resultList := result.([]interface{})
				taskResult = append(taskResult, resultList...)
			}
		}
		task.locker.Lock()
		if taskErr != nil { // any worker error means the task failed
			task.Status = TaskStatusFailed
		} else {
			// Jobs or end markers left in the queue mean some workers quit via
			// quitChan before draining it, i.e. the task was canceled.
			if len(task.taskChan) > 0 {
				taskErr = ErrTaskIsCanceled
				task.Status = TaskStatusCanceled
			} else {
				task.Status = TaskStatusFinished
			}
		}
		task.Err = taskErr
		task.Result = taskResult
		task.TerminatedAt = time.Now()
		task.locker.Unlock()
		globals.SugarLogger.Debugf("ParallelTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr)
		close(task.finishChan)
		close(task.subFinishChan)
		// NOTE(review): assumes BaseTask.Cancel does not itself close quitChan;
		// otherwise this would panic on the error path - confirm.
		close(task.quitChan)
		if task.resultHandler != nil {
			task.resultHandler(task.Name, taskResult, task.Err)
		}
	})
}