copy.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package buf
  2. import (
  3. "io"
  4. "time"
  5. "github.com/xtls/xray-core/common/errors"
  6. "github.com/xtls/xray-core/common/signal"
  7. "github.com/xtls/xray-core/features/stats"
  8. )
  9. type dataHandler func(MultiBuffer)
  10. type copyHandler struct {
  11. onData []dataHandler
  12. }
  13. // SizeCounter is for counting bytes copied by Copy().
  14. type SizeCounter struct {
  15. Size int64
  16. }
  17. // CopyOption is an option for copying data.
  18. type CopyOption func(*copyHandler)
  19. // UpdateActivity is a CopyOption to update activity on each data copy operation.
  20. func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
  21. return func(handler *copyHandler) {
  22. handler.onData = append(handler.onData, func(MultiBuffer) {
  23. timer.Update()
  24. })
  25. }
  26. }
  27. // CountSize is a CopyOption that sums the total size of data copied into the given SizeCounter.
  28. func CountSize(sc *SizeCounter) CopyOption {
  29. return func(handler *copyHandler) {
  30. handler.onData = append(handler.onData, func(b MultiBuffer) {
  31. sc.Size += int64(b.Len())
  32. })
  33. }
  34. }
  35. // AddToStatCounter a CopyOption add to stat counter
  36. func AddToStatCounter(sc stats.Counter) CopyOption {
  37. return func(handler *copyHandler) {
  38. handler.onData = append(handler.onData, func(b MultiBuffer) {
  39. if sc != nil {
  40. sc.Add(int64(b.Len()))
  41. }
  42. })
  43. }
  44. }
  45. type readError struct {
  46. error
  47. }
  48. func (e readError) Error() string {
  49. return e.error.Error()
  50. }
  51. func (e readError) Unwrap() error {
  52. return e.error
  53. }
  54. // IsReadError returns true if the error in Copy() comes from reading.
  55. func IsReadError(err error) bool {
  56. _, ok := err.(readError)
  57. return ok
  58. }
  59. type writeError struct {
  60. error
  61. }
  62. func (e writeError) Error() string {
  63. return e.error.Error()
  64. }
  65. func (e writeError) Unwrap() error {
  66. return e.error
  67. }
  68. // IsWriteError returns true if the error in Copy() comes from writing.
  69. func IsWriteError(err error) bool {
  70. _, ok := err.(writeError)
  71. return ok
  72. }
  73. func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
  74. for {
  75. buffer, err := reader.ReadMultiBuffer()
  76. if !buffer.IsEmpty() {
  77. for _, handler := range handler.onData {
  78. handler(buffer)
  79. }
  80. if werr := writer.WriteMultiBuffer(buffer); werr != nil {
  81. return writeError{werr}
  82. }
  83. }
  84. if err != nil {
  85. return readError{err}
  86. }
  87. }
  88. }
  89. // Copy dumps all payload from reader to writer or stops when an error occurs. It returns nil when EOF.
  90. func Copy(reader Reader, writer Writer, options ...CopyOption) error {
  91. var handler copyHandler
  92. for _, option := range options {
  93. option(&handler)
  94. }
  95. err := copyInternal(reader, writer, &handler)
  96. if err != nil && errors.Cause(err) != io.EOF {
  97. return err
  98. }
  99. return nil
  100. }
  101. var ErrNotTimeoutReader = errors.New("not a TimeoutReader")
  102. func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error {
  103. timeoutReader, ok := reader.(TimeoutReader)
  104. if !ok {
  105. return ErrNotTimeoutReader
  106. }
  107. mb, err := timeoutReader.ReadMultiBufferTimeout(timeout)
  108. if err != nil {
  109. return err
  110. }
  111. return writer.WriteMultiBuffer(mb)
  112. }