io.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package buf
  2. import (
  3. "io"
  4. "net"
  5. "os"
  6. "syscall"
  7. "time"
  8. "github.com/xtls/xray-core/features/stats"
  9. "github.com/xtls/xray-core/transport/internet/stat"
  10. )
  11. // Reader extends io.Reader with MultiBuffer.
  12. type Reader interface {
  13. // ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
  14. ReadMultiBuffer() (MultiBuffer, error)
  15. }
  16. // ErrReadTimeout is an error that happens with IO timeout.
  17. var ErrReadTimeout = newError("IO timeout")
  18. // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
  19. type TimeoutReader interface {
  20. ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
  21. }
  22. // Writer extends io.Writer with MultiBuffer.
  23. type Writer interface {
  24. // WriteMultiBuffer writes a MultiBuffer into underlying writer.
  25. WriteMultiBuffer(MultiBuffer) error
  26. }
  27. // WriteAllBytes ensures all bytes are written into the given writer.
  28. func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
  29. wc := 0
  30. defer func() {
  31. if c != nil {
  32. c.Add(int64(wc))
  33. }
  34. }()
  35. for len(payload) > 0 {
  36. n, err := writer.Write(payload)
  37. wc += n
  38. if err != nil {
  39. return err
  40. }
  41. payload = payload[n:]
  42. }
  43. return nil
  44. }
  45. func isPacketReader(reader io.Reader) bool {
  46. _, ok := reader.(net.PacketConn)
  47. return ok
  48. }
  49. // NewReader creates a new Reader.
  50. // The Reader instance doesn't take the ownership of reader.
  51. func NewReader(reader io.Reader) Reader {
  52. if mr, ok := reader.(Reader); ok {
  53. return mr
  54. }
  55. if isPacketReader(reader) {
  56. return &PacketReader{
  57. Reader: reader,
  58. }
  59. }
  60. _, isFile := reader.(*os.File)
  61. if !isFile && useReadv {
  62. if sc, ok := reader.(syscall.Conn); ok {
  63. rawConn, err := sc.SyscallConn()
  64. if err != nil {
  65. newError("failed to get sysconn").Base(err).WriteToLog()
  66. } else {
  67. var counter stats.Counter
  68. if statConn, ok := reader.(*stat.CounterConnection); ok {
  69. reader = statConn.Connection
  70. counter = statConn.ReadCounter
  71. }
  72. return NewReadVReader(reader, rawConn, counter)
  73. }
  74. }
  75. }
  76. return &SingleReader{
  77. Reader: reader,
  78. }
  79. }
  80. // NewPacketReader creates a new PacketReader based on the given reader.
  81. func NewPacketReader(reader io.Reader) Reader {
  82. if mr, ok := reader.(Reader); ok {
  83. return mr
  84. }
  85. return &PacketReader{
  86. Reader: reader,
  87. }
  88. }
  89. func isPacketWriter(writer io.Writer) bool {
  90. if _, ok := writer.(net.PacketConn); ok {
  91. return true
  92. }
  93. // If the writer doesn't implement syscall.Conn, it is probably not a TCP connection.
  94. if _, ok := writer.(syscall.Conn); !ok {
  95. return true
  96. }
  97. return false
  98. }
  99. // NewWriter creates a new Writer.
  100. func NewWriter(writer io.Writer) Writer {
  101. if mw, ok := writer.(Writer); ok {
  102. return mw
  103. }
  104. iConn := writer
  105. if statConn, ok := writer.(*stat.CounterConnection); ok {
  106. iConn = statConn.Connection
  107. }
  108. if isPacketWriter(iConn) {
  109. return &SequentialWriter{
  110. Writer: writer,
  111. }
  112. }
  113. var counter stats.Counter
  114. if statConn, ok := writer.(*stat.CounterConnection); ok {
  115. counter = statConn.WriteCounter
  116. }
  117. return &BufferToBytesWriter{
  118. Writer: iConn,
  119. counter: counter,
  120. }
  121. }