207 lines
6.0 KiB
Go
207 lines
6.0 KiB
Go
package tasksch
|
||
|
||
import (
|
||
"errors"
|
||
"time"
|
||
|
||
"git.rosy.net.cn/baseapi/utils"
|
||
"git.rosy.net.cn/jx-callback/business/jxutils"
|
||
"git.rosy.net.cn/jx-callback/globals"
|
||
)
|
||
|
||
// Limits applied to the number of worker goroutines of a parallel task.
const (
	// DefParallelCount is the worker count used when a config leaves
	// ParallelCount at zero.
	DefParallelCount = 10

	// MaxParallelCount is the upper bound NewParallelTask enforces on the
	// requested worker count.
	MaxParallelCount = 10
)
|
||
|
||
type (
	// WorkFunc processes one batch of items. params are the extra arguments
	// given when the task was created. A non-nil retVal is flattened into the
	// task result; a non-nil err marks the batch as failed.
	WorkFunc func(batchItemList []interface{}, params ...interface{}) (retVal interface{}, err error)

	// ResultHandlerFunc is the callback invoked once a task has terminated,
	// receiving the task name, the collected results and the final error.
	ResultHandlerFunc func(taskName string, result []interface{}, err error)
)
|
||
|
||
type ParallelConfig struct {
|
||
ParallelCount int
|
||
BatchSize int
|
||
IsContinueWhenError bool
|
||
ResultHandler ResultHandlerFunc
|
||
}
|
||
|
||
type Task struct {
|
||
BaseTask
|
||
|
||
resultHandler ResultHandlerFunc
|
||
worker WorkFunc
|
||
jobList [][]interface{}
|
||
taskChan chan []interface{}
|
||
subFinishChan chan interface{}
|
||
}
|
||
|
||
// ErrTaskNotFinished is the sentinel error for a task that has not completed
// yet (message: "task is not finished yet").
var ErrTaskNotFinished = errors.New("任务还未完成")

// ErrTaskIsCanceled is the sentinel error recorded on a task that was
// canceled (message: "task was canceled").
var ErrTaskIsCanceled = errors.New("任务被取消了")
|
||
|
||
func NewParallelConfig() *ParallelConfig {
|
||
return &ParallelConfig{
|
||
ParallelCount: DefParallelCount,
|
||
BatchSize: 1,
|
||
IsContinueWhenError: false,
|
||
ResultHandler: nil,
|
||
}
|
||
}
|
||
|
||
func (c *ParallelConfig) SetParallelCount(parallelCount int) *ParallelConfig {
|
||
c.ParallelCount = parallelCount
|
||
return c
|
||
}
|
||
|
||
func (c *ParallelConfig) SetBatchSize(batchSize int) *ParallelConfig {
|
||
c.BatchSize = batchSize
|
||
return c
|
||
}
|
||
|
||
func (c *ParallelConfig) SetIsContinueWhenError(isContinueWhenError bool) *ParallelConfig {
|
||
c.IsContinueWhenError = isContinueWhenError
|
||
return c
|
||
}
|
||
|
||
func (c *ParallelConfig) SetResultHandler(resultHandler ResultHandlerFunc) *ParallelConfig {
|
||
c.ResultHandler = resultHandler
|
||
return c
|
||
}
|
||
|
||
func NewParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *Task {
|
||
if config == nil {
|
||
config = NewParallelConfig()
|
||
}
|
||
if config.ParallelCount == 0 {
|
||
config.ParallelCount = DefParallelCount
|
||
}
|
||
if config.ParallelCount > MaxParallelCount {
|
||
config.ParallelCount = MaxParallelCount
|
||
}
|
||
realItemList := utils.Interface2Slice(itemList)
|
||
jobList := jxutils.SplitSlice(realItemList, config.BatchSize)
|
||
jobListLen := len(jobList)
|
||
if config.ParallelCount > jobListLen {
|
||
config.ParallelCount = jobListLen
|
||
}
|
||
task := &Task{
|
||
subFinishChan: make(chan interface{}, config.ParallelCount),
|
||
taskChan: make(chan []interface{}, len(realItemList)+config.ParallelCount), // 确保能装下所有taskitem,加结束标记
|
||
resultHandler: config.ResultHandler,
|
||
worker: worker,
|
||
jobList: jobList,
|
||
}
|
||
task.Init(config.ParallelCount, config.BatchSize, config.IsContinueWhenError, params, taskName, userName, len(realItemList), jobListLen)
|
||
return task
|
||
}
|
||
|
||
func RunParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *Task {
|
||
task := NewParallelTask(taskName, userName, config, worker, itemList, params...)
|
||
task.Run()
|
||
return task
|
||
}
|
||
|
||
func RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHandlerFunc, parallelCount, batchSize int, userName string, worker WorkFunc, itemList interface{}, params ...interface{}) *Task {
|
||
config := NewParallelConfig()
|
||
config.BatchSize = batchSize
|
||
config.IsContinueWhenError = isContinueWhenError
|
||
config.ParallelCount = parallelCount
|
||
config.ResultHandler = resultHandler
|
||
task := NewParallelTask(taskName, userName, config, worker, itemList, params...)
|
||
task.Run()
|
||
return task
|
||
}
|
||
|
||
func (task *Task) Run() {
|
||
task.run(func() {
|
||
globals.SugarLogger.Debugf("ParallelTask.Run %s", task.Name)
|
||
for i := 0; i < task.ParallelCount; i++ {
|
||
go func() {
|
||
var chanRetVal interface{}
|
||
retVal := make([]interface{}, 0)
|
||
for {
|
||
select {
|
||
case <-task.quitChan: // 取消
|
||
goto end
|
||
case job := <-task.taskChan:
|
||
if job == nil { // 任务完成
|
||
chanRetVal = retVal
|
||
goto end
|
||
} else {
|
||
result, err := task.worker(job, task.params...)
|
||
// globals.SugarLogger.Debugf("ParallelTask.Run %s, after call worker result:%v, err:%v", task.Name, result, err)
|
||
task.finishedOneJob(len(job), err)
|
||
if err == nil {
|
||
if result != nil {
|
||
retVal = append(retVal, utils.Interface2Slice(result)...)
|
||
}
|
||
} else {
|
||
globals.SugarLogger.Infof("ParallelTask.Run %s, subtask(job:%s, params:%s) result:%v, failed with error:%v", task.Name, utils.Format4Output(job, true), utils.Format4Output(task.params, true), result, err)
|
||
if !task.IsContinueWhenError { // 出错
|
||
chanRetVal = err
|
||
goto end
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
end:
|
||
// globals.SugarLogger.Debugf("ParallelTask.Run %s, put to chann chanRetVal:%v", task.Name, chanRetVal)
|
||
task.locker.RLock()
|
||
if task.Status < TaskStatusEndBegin {
|
||
task.subFinishChan <- chanRetVal
|
||
}
|
||
task.locker.RUnlock()
|
||
}()
|
||
}
|
||
for _, job := range task.jobList {
|
||
task.taskChan <- job
|
||
}
|
||
for i := 0; i < task.ParallelCount; i++ {
|
||
task.taskChan <- nil
|
||
}
|
||
|
||
taskResult := make([]interface{}, 0)
|
||
var taskErr error
|
||
for i := 0; i < task.ParallelCount; i++ {
|
||
result := <-task.subFinishChan
|
||
// globals.SugarLogger.Debugf("ParallelTask.Run %s, received from chann result:%v", taskName, result)
|
||
if err, ok := result.(error); ok {
|
||
task.Cancel()
|
||
taskResult = nil
|
||
taskErr = err
|
||
break // 出错情况下是否需要直接跳出?
|
||
} else if result != nil {
|
||
resultList := result.([]interface{})
|
||
taskResult = append(taskResult, resultList...)
|
||
}
|
||
}
|
||
|
||
task.locker.Lock()
|
||
if taskErr != nil { // 如果有错误,肯定就是失败了
|
||
task.Status = TaskStatusFailed
|
||
} else {
|
||
if len(task.taskChan) > 0 {
|
||
taskErr = ErrTaskIsCanceled
|
||
task.Status = TaskStatusCanceled
|
||
} else {
|
||
task.Status = TaskStatusFinished
|
||
}
|
||
}
|
||
task.Err = taskErr
|
||
task.Result = taskResult
|
||
task.TerminatedAt = time.Now()
|
||
task.locker.Unlock()
|
||
|
||
globals.SugarLogger.Debugf("ParallelTask.Run %s, result:%v, err:%v", task.Name, taskResult, taskErr)
|
||
|
||
close(task.finishChan)
|
||
close(task.subFinishChan)
|
||
close(task.quitChan)
|
||
|
||
if task.resultHandler != nil {
|
||
task.resultHandler(task.Name, taskResult, task.Err)
|
||
}
|
||
})
|
||
}
|