task.go 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. package task
  2. import (
  3. "context"
  4. "github.com/xtls/xray-core/common/signal/semaphore"
  5. )
  6. // OnSuccess executes g() after f() returns nil.
  7. func OnSuccess(f func() error, g func() error) func() error {
  8. return func() error {
  9. if err := f(); err != nil {
  10. return err
  11. }
  12. return g()
  13. }
  14. }
  15. // Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
  16. func Run(ctx context.Context, tasks ...func() error) error {
  17. n := len(tasks)
  18. s := semaphore.New(n)
  19. done := make(chan error, 1)
  20. for _, task := range tasks {
  21. <-s.Wait()
  22. go func(f func() error) {
  23. err := f()
  24. if err == nil {
  25. s.Signal()
  26. return
  27. }
  28. select {
  29. case done <- err:
  30. default:
  31. }
  32. }(task)
  33. }
  34. /*
  35. if altctx := ctx.Value("altctx"); altctx != nil {
  36. ctx = altctx.(context.Context)
  37. }
  38. */
  39. for i := 0; i < n; i++ {
  40. select {
  41. case err := <-done:
  42. return err
  43. case <-ctx.Done():
  44. return ctx.Err()
  45. case <-s.Wait():
  46. }
  47. }
  48. /*
  49. if cancel := ctx.Value("cancel"); cancel != nil {
  50. cancel.(context.CancelFunc)()
  51. }
  52. */
  53. return nil
  54. }