Files
jx-callback/business/jxutils/tasksch/parellel_task.go
2018-10-19 18:34:31 +08:00

202 lines
5.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tasksch
import (
"errors"
"time"
"git.rosy.net.cn/baseapi/utils"
"git.rosy.net.cn/jx-callback/business/jxutils"
"git.rosy.net.cn/jx-callback/globals"
)
// Limits on how many worker goroutines a parallel task may run.
const (
	// MaxParallelCount is the upper bound NewParallelTask clamps ParallelCount to.
	MaxParallelCount = 10
	// DefParallelCount is the worker count used by NewParallelConfig and when a
	// config leaves ParallelCount at zero.
	DefParallelCount = 10
)
type (
	// WorkFunc processes one batch of items. params are the extra arguments
	// given when the task was created. A non-nil retVal is flattened into a
	// slice and appended to the task's result; a non-nil err marks the batch
	// as failed.
	WorkFunc func(batchItemList []interface{}, params ...interface{}) (retVal interface{}, err error)

	// ResultHandlerFunc is invoked once a task has terminated, with the task
	// name, the collected results and the task's final error (nil on success).
	ResultHandlerFunc func(taskName string, result []interface{}, err error)
)
// ParallelConfig holds the tuning options for a parallel task. A nil config,
// or one obtained from NewParallelConfig, gives the defaults.
type ParallelConfig struct {
	// ParallelCount is the number of worker goroutines. Zero means
	// DefParallelCount; values above MaxParallelCount are capped, and it is
	// never raised above the number of batches.
	ParallelCount int
	// BatchSize is the number of items handed to each WorkFunc call.
	BatchSize int
	// IsContinueWhenError keeps workers running after a WorkFunc error instead
	// of failing the whole task.
	IsContinueWhenError bool
	// ResultHandler, if non-nil, is called once the task has terminated.
	ResultHandler ResultHandlerFunc
}
// Task is a parallel task: a list of batches processed by a pool of worker
// goroutines. Shared bookkeeping (status, counters, channels, locking) lives
// in the embedded BaseTask, which is defined elsewhere in this package.
type Task struct {
	BaseTask
	// resultHandler is called once when the task terminates (may be nil).
	resultHandler ResultHandlerFunc
	// worker is invoked for every batch in jobList.
	worker WorkFunc
	// jobList holds the input items split into batches of BatchSize.
	jobList [][]interface{}
	// taskChan feeds batches to the workers; a nil batch is the end marker
	// telling a worker to stop and report its results.
	taskChan chan []interface{}
	// subFinishChan receives one report per worker: its []interface{}
	// results, an error, or nil when it was cancelled.
	subFinishChan chan interface{}
}
var (
ErrTaskNotFinished = errors.New("任务还未完成")
ErrTaskIsCanceled = errors.New("任务被取消了")
)
// NewParallelConfig returns a config with the default settings:
// DefParallelCount workers, one item per batch, stop on the first error and
// no result handler.
func NewParallelConfig() *ParallelConfig {
	cfg := new(ParallelConfig)
	cfg.ParallelCount = DefParallelCount
	cfg.BatchSize = 1
	// IsContinueWhenError (false) and ResultHandler (nil) keep their zero values.
	return cfg
}
// NewParallelTask creates, but does not start, a Task. The task splits
// itemList into batches of config.BatchSize items and processes them with up
// to config.ParallelCount concurrent calls to worker. params are passed
// through to every worker call.
//
// config may be nil, in which case NewParallelConfig() is used. The caller's
// config is never modified: normalisation of ParallelCount is applied to a
// private copy, so one config value can safely be shared by several tasks.
//
// itemList is converted with utils.Interface2Slice (presumably it must be a
// slice value; confirm against that helper).
func NewParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *Task {
	// Copy so that clamping below does not leak back into the caller's config.
	var cfg ParallelConfig
	if config != nil {
		cfg = *config
	} else {
		cfg = *NewParallelConfig()
	}
	// Non-positive counts fall back to the default (a negative count would
	// make the channel allocations below panic); large ones are capped.
	if cfg.ParallelCount <= 0 {
		cfg.ParallelCount = DefParallelCount
	}
	if cfg.ParallelCount > MaxParallelCount {
		cfg.ParallelCount = MaxParallelCount
	}

	realItemList := utils.Interface2Slice(itemList)
	// TODO(review): BatchSize <= 0 is passed to SplitSlice unchanged; confirm
	// how that helper handles it.
	jobList := jxutils.SplitSlice(realItemList, cfg.BatchSize)
	// Never start more workers than there are batches.
	if cfg.ParallelCount > len(jobList) {
		cfg.ParallelCount = len(jobList)
	}

	now := time.Now()
	task := &Task{
		BaseTask: BaseTask{
			ParallelCount:       cfg.ParallelCount,
			BatchSize:           cfg.BatchSize,
			IsContinueWhenError: cfg.IsContinueWhenError,
			params:              params,
			ID:                  utils.GetUUID(),
			Name:                taskName,
			CreatedAt:           now,
			CreatedBy:           userName,
			UpdatedAt:           now,
			TotalJobCount:       len(jobList),
			TotalItemCount:      len(realItemList),
			quitChan:            make(chan int, cfg.ParallelCount),
			finishChan:          make(chan int, 2),
			Status:              TaskStatusWorking,
		},
		subFinishChan: make(chan interface{}, cfg.ParallelCount),
		// Large enough to hold every batch plus one nil end marker per worker
		// (presumably len(realItemList) >= len(jobList), since each batch holds
		// at least one item), so Run can queue everything without blocking.
		taskChan:      make(chan []interface{}, len(realItemList)+cfg.ParallelCount),
		resultHandler: cfg.ResultHandler,
		worker:        worker,
		jobList:       jobList,
	}
	task.C = task.finishChan
	return task
}
// RunParallelTask creates a task with NewParallelTask and starts it
// immediately, returning the running task.
func RunParallelTask(taskName string, userName string, config *ParallelConfig, worker WorkFunc, itemList interface{}, params ...interface{}) *Task {
	return NewParallelTask(taskName, userName, config, worker, itemList, params...).Run()
}
// RunTask is a positional-argument convenience wrapper: it builds a
// ParallelConfig from the given options, then creates and starts the task,
// returning it.
func RunTask(taskName string, isContinueWhenError bool, resultHandler ResultHandlerFunc, parallelCount, batchSize int, userName string, worker WorkFunc, itemList interface{}, params ...interface{}) *Task {
	cfg := &ParallelConfig{
		ParallelCount:       parallelCount,
		BatchSize:           batchSize,
		IsContinueWhenError: isContinueWhenError,
		ResultHandler:       resultHandler,
	}
	return NewParallelTask(taskName, userName, cfg, worker, itemList, params...).Run()
}
// Run starts the task asynchronously and returns the task itself so calls can
// be chained, e.g. NewParallelTask(...).Run().
//
// A coordinator goroutine performs these steps:
//  1. starts task.ParallelCount worker goroutines that pull batches from
//     task.taskChan and call task.worker on each one;
//  2. queues every batch from task.jobList, followed by one nil end marker
//     per worker;
//  3. collects one report per worker from task.subFinishChan, calling
//     task.Cancel() on the first error reported;
//  4. records the final status, error and result under task.locker, closes
//     the task's channels and finally calls task.resultHandler (if set).
//
// A worker stops in three cases:
//   - when it sees a nil end marker, reporting its accumulated results;
//   - when task.worker fails and IsContinueWhenError is false, reporting the
//     error;
//   - when it receives a quit signal on task.quitChan, reporting nil.
func (task *Task) Run() *Task {
	// runWorker processes batches until told to stop. It returns the value the
	// worker reports to the coordinator.
	runWorker := func() interface{} {
		retVal := make([]interface{}, 0)
		for {
			// Check for a quit signal first. When quitChan and taskChan are both
			// ready, select chooses between them at random, so without this a
			// cancelled worker (or one still alive after close(quitChan) on the
			// error path) could keep running further batches.
			select {
			case <-task.quitChan: // cancelled
				return nil
			default:
			}
			select {
			case <-task.quitChan: // cancelled
				return nil
			case job := <-task.taskChan:
				if job == nil { // end marker: all batches have been handed out
					return retVal
				}
				result, err := task.worker(job, task.params...)
				globals.SugarLogger.Debugf("RunTask %s, after call worker result:%v, err:%v", task.Name, result, err)
				task.finishedOneJob(len(job), err)
				if err != nil {
					if !task.IsContinueWhenError {
						return err
					}
				} else if result != nil {
					retVal = append(retVal, utils.Interface2Slice(result)...)
				}
			}
		}
	}

	go func() {
		globals.SugarLogger.Debugf("Run ParallelTask %s", task.Name)
		for i := 0; i < task.ParallelCount; i++ {
			go func() {
				chanRetVal := runWorker()
				globals.SugarLogger.Debugf("RunTask %s, put to chann chanRetVal:%v", task.Name, chanRetVal)
				// The coordinator sets the final status under task.locker before
				// closing subFinishChan. Reporting only while the status is below
				// TaskStatusEndBegin avoids sending on a closed channel; this
				// assumes terminal statuses sort at or after TaskStatusEndBegin
				// (defined elsewhere — confirm).
				task.locker.RLock()
				if task.Status < TaskStatusEndBegin {
					task.subFinishChan <- chanRetVal
				}
				task.locker.RUnlock()
			}()
		}

		for _, job := range task.jobList {
			task.taskChan <- job
		}
		// One end marker per worker.
		for i := 0; i < task.ParallelCount; i++ {
			task.taskChan <- nil
		}

		taskResult := make([]interface{}, 0)
		var taskErr error
		for i := 0; i < task.ParallelCount; i++ {
			result := <-task.subFinishChan
			if err, ok := result.(error); ok {
				// First failure: stop the remaining workers and discard the
				// partial results.
				task.Cancel()
				taskResult = nil
				taskErr = err
				break
			}
			if result != nil {
				taskResult = append(taskResult, result.([]interface{})...)
			}
		}

		task.locker.Lock()
		switch {
		case taskErr != nil: // any worker error means the task failed
			task.Status = TaskStatusFailed
		case len(task.taskChan) > 0:
			// Workers stopped while batches or end markers were still queued,
			// so they were cancelled.
			taskErr = ErrTaskIsCanceled
			task.Status = TaskStatusCanceled
		default:
			task.Status = TaskStatusFinished
		}
		task.err = taskErr
		task.result = taskResult
		task.TerminatedAt = time.Now()
		task.locker.Unlock()

		globals.SugarLogger.Debugf("RunTask %s, result:%v, err:%v", task.Name, taskResult, taskErr)
		close(task.finishChan)
		close(task.subFinishChan)
		// Closing quitChan also makes every still-running worker stop before
		// it takes another batch.
		close(task.quitChan)
		if task.resultHandler != nil {
			task.resultHandler(task.Name, taskResult, task.err)
		}
	}()
	return task
}