io.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package buf
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "os"
  7. "syscall"
  8. "time"
  9. "github.com/xtls/xray-core/common/errors"
  10. "github.com/xtls/xray-core/features/stats"
  11. "github.com/xtls/xray-core/transport/internet/stat"
  12. )
  13. // Reader extends io.Reader with MultiBuffer.
  14. type Reader interface {
  15. // ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
  16. ReadMultiBuffer() (MultiBuffer, error)
  17. }
  18. // ErrReadTimeout is an error that happens with IO timeout.
  19. var ErrReadTimeout = errors.New("IO timeout")
  20. // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
  21. type TimeoutReader interface {
  22. Reader
  23. ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
  24. }
  25. type TimeoutWrapperReader struct {
  26. Reader
  27. stats.Counter
  28. mb MultiBuffer
  29. err error
  30. done chan struct{}
  31. }
  32. func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) {
  33. if r.done != nil {
  34. <-r.done
  35. r.done = nil
  36. if r.Counter != nil {
  37. r.Counter.Add(int64(r.mb.Len()))
  38. }
  39. return r.mb, r.err
  40. }
  41. r.mb, r.err = r.Reader.ReadMultiBuffer()
  42. if r.Counter != nil {
  43. r.Counter.Add(int64(r.mb.Len()))
  44. }
  45. return r.mb, r.err
  46. }
  47. func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) {
  48. if r.done == nil {
  49. r.done = make(chan struct{})
  50. go func() {
  51. r.mb, r.err = r.Reader.ReadMultiBuffer()
  52. close(r.done)
  53. }()
  54. }
  55. timeout := make(chan struct{})
  56. go func() {
  57. time.Sleep(duration)
  58. close(timeout)
  59. }()
  60. select {
  61. case <-r.done:
  62. r.done = nil
  63. if r.Counter != nil {
  64. r.Counter.Add(int64(r.mb.Len()))
  65. }
  66. return r.mb, r.err
  67. case <-timeout:
  68. return nil, nil
  69. }
  70. }
  71. // Writer extends io.Writer with MultiBuffer.
  72. type Writer interface {
  73. // WriteMultiBuffer writes a MultiBuffer into underlying writer.
  74. WriteMultiBuffer(MultiBuffer) error
  75. }
  76. // WriteAllBytes ensures all bytes are written into the given writer.
  77. func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
  78. wc := 0
  79. defer func() {
  80. if c != nil {
  81. c.Add(int64(wc))
  82. }
  83. }()
  84. for len(payload) > 0 {
  85. n, err := writer.Write(payload)
  86. wc += n
  87. if err != nil {
  88. return err
  89. }
  90. payload = payload[n:]
  91. }
  92. return nil
  93. }
  94. func isPacketReader(reader io.Reader) bool {
  95. _, ok := reader.(net.PacketConn)
  96. return ok
  97. }
  98. // NewReader creates a new Reader.
  99. // The Reader instance doesn't take the ownership of reader.
  100. func NewReader(reader io.Reader) Reader {
  101. if mr, ok := reader.(Reader); ok {
  102. return mr
  103. }
  104. if isPacketReader(reader) {
  105. return &PacketReader{
  106. Reader: reader,
  107. }
  108. }
  109. _, isFile := reader.(*os.File)
  110. if !isFile && useReadv {
  111. if sc, ok := reader.(syscall.Conn); ok {
  112. rawConn, err := sc.SyscallConn()
  113. if err != nil {
  114. errors.LogInfoInner(context.Background(), err, "failed to get sysconn")
  115. } else {
  116. var counter stats.Counter
  117. if statConn, ok := reader.(*stat.CounterConnection); ok {
  118. reader = statConn.Connection
  119. counter = statConn.ReadCounter
  120. }
  121. return NewReadVReader(reader, rawConn, counter)
  122. }
  123. }
  124. }
  125. return &SingleReader{
  126. Reader: reader,
  127. }
  128. }
  129. // NewPacketReader creates a new PacketReader based on the given reader.
  130. func NewPacketReader(reader io.Reader) Reader {
  131. if mr, ok := reader.(Reader); ok {
  132. return mr
  133. }
  134. return &PacketReader{
  135. Reader: reader,
  136. }
  137. }
  138. func isPacketWriter(writer io.Writer) bool {
  139. if _, ok := writer.(net.PacketConn); ok {
  140. return true
  141. }
  142. // If the writer doesn't implement syscall.Conn, it is probably not a TCP connection.
  143. if _, ok := writer.(syscall.Conn); !ok {
  144. return true
  145. }
  146. return false
  147. }
  148. // NewWriter creates a new Writer.
  149. func NewWriter(writer io.Writer) Writer {
  150. if mw, ok := writer.(Writer); ok {
  151. return mw
  152. }
  153. iConn := writer
  154. if statConn, ok := writer.(*stat.CounterConnection); ok {
  155. iConn = statConn.Connection
  156. }
  157. if isPacketWriter(iConn) {
  158. return &SequentialWriter{
  159. Writer: writer,
  160. }
  161. }
  162. var counter stats.Counter
  163. if statConn, ok := writer.(*stat.CounterConnection); ok {
  164. counter = statConn.WriteCounter
  165. }
  166. return &BufferToBytesWriter{
  167. Writer: iConn,
  168. counter: counter,
  169. }
  170. }