writer.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package buf
  2. import (
  3. "io"
  4. "net"
  5. "sync"
  6. "github.com/xtls/xray-core/common"
  7. "github.com/xtls/xray-core/common/errors"
  8. )
  9. // BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
  10. type BufferToBytesWriter struct {
  11. io.Writer
  12. cache [][]byte
  13. }
  14. // WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
  15. func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
  16. defer ReleaseMulti(mb)
  17. size := mb.Len()
  18. if size == 0 {
  19. return nil
  20. }
  21. if len(mb) == 1 {
  22. return WriteAllBytes(w.Writer, mb[0].Bytes())
  23. }
  24. if cap(w.cache) < len(mb) {
  25. w.cache = make([][]byte, 0, len(mb))
  26. }
  27. bs := w.cache
  28. for _, b := range mb {
  29. bs = append(bs, b.Bytes())
  30. }
  31. defer func() {
  32. for idx := range bs {
  33. bs[idx] = nil
  34. }
  35. }()
  36. nb := net.Buffers(bs)
  37. for size > 0 {
  38. n, err := nb.WriteTo(w.Writer)
  39. if err != nil {
  40. return err
  41. }
  42. size -= int32(n)
  43. }
  44. return nil
  45. }
  46. // ReadFrom implements io.ReaderFrom.
  47. func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
  48. var sc SizeCounter
  49. err := Copy(NewReader(reader), w, CountSize(&sc))
  50. return sc.Size, err
  51. }
  52. // BufferedWriter is a Writer with internal buffer.
  53. type BufferedWriter struct {
  54. sync.Mutex
  55. writer Writer
  56. buffer *Buffer
  57. buffered bool
  58. }
  59. // NewBufferedWriter creates a new BufferedWriter.
  60. func NewBufferedWriter(writer Writer) *BufferedWriter {
  61. return &BufferedWriter{
  62. writer: writer,
  63. buffer: New(),
  64. buffered: true,
  65. }
  66. }
  67. // WriteByte implements io.ByteWriter.
  68. func (w *BufferedWriter) WriteByte(c byte) error {
  69. return common.Error2(w.Write([]byte{c}))
  70. }
  71. // Write implements io.Writer.
  72. func (w *BufferedWriter) Write(b []byte) (int, error) {
  73. if len(b) == 0 {
  74. return 0, nil
  75. }
  76. w.Lock()
  77. defer w.Unlock()
  78. if !w.buffered {
  79. if writer, ok := w.writer.(io.Writer); ok {
  80. return writer.Write(b)
  81. }
  82. }
  83. totalBytes := 0
  84. for len(b) > 0 {
  85. if w.buffer == nil {
  86. w.buffer = New()
  87. }
  88. nBytes, err := w.buffer.Write(b)
  89. totalBytes += nBytes
  90. if err != nil {
  91. return totalBytes, err
  92. }
  93. if !w.buffered || w.buffer.IsFull() {
  94. if err := w.flushInternal(); err != nil {
  95. return totalBytes, err
  96. }
  97. }
  98. b = b[nBytes:]
  99. }
  100. return totalBytes, nil
  101. }
  102. // WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
  103. func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
  104. if b.IsEmpty() {
  105. return nil
  106. }
  107. w.Lock()
  108. defer w.Unlock()
  109. if !w.buffered {
  110. return w.writer.WriteMultiBuffer(b)
  111. }
  112. reader := MultiBufferContainer{
  113. MultiBuffer: b,
  114. }
  115. defer reader.Close()
  116. for !reader.MultiBuffer.IsEmpty() {
  117. if w.buffer == nil {
  118. w.buffer = New()
  119. }
  120. common.Must2(w.buffer.ReadFrom(&reader))
  121. if w.buffer.IsFull() {
  122. if err := w.flushInternal(); err != nil {
  123. return err
  124. }
  125. }
  126. }
  127. return nil
  128. }
  129. // Flush flushes buffered content into underlying writer.
  130. func (w *BufferedWriter) Flush() error {
  131. w.Lock()
  132. defer w.Unlock()
  133. return w.flushInternal()
  134. }
  135. func (w *BufferedWriter) flushInternal() error {
  136. if w.buffer.IsEmpty() {
  137. return nil
  138. }
  139. b := w.buffer
  140. w.buffer = nil
  141. if writer, ok := w.writer.(io.Writer); ok {
  142. err := WriteAllBytes(writer, b.Bytes())
  143. b.Release()
  144. return err
  145. }
  146. return w.writer.WriteMultiBuffer(MultiBuffer{b})
  147. }
  148. // SetBuffered sets whether the internal buffer is used. If set to false, Flush() will be called to clear the buffer.
  149. func (w *BufferedWriter) SetBuffered(f bool) error {
  150. w.Lock()
  151. defer w.Unlock()
  152. w.buffered = f
  153. if !f {
  154. return w.flushInternal()
  155. }
  156. return nil
  157. }
  158. // ReadFrom implements io.ReaderFrom.
  159. func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
  160. if err := w.SetBuffered(false); err != nil {
  161. return 0, err
  162. }
  163. var sc SizeCounter
  164. err := Copy(NewReader(reader), w, CountSize(&sc))
  165. return sc.Size, err
  166. }
  167. // Close implements io.Closable.
  168. func (w *BufferedWriter) Close() error {
  169. if err := w.Flush(); err != nil {
  170. return err
  171. }
  172. return common.Close(w.writer)
  173. }
  174. // SequentialWriter is a Writer that writes MultiBuffer sequentially into the underlying io.Writer.
  175. type SequentialWriter struct {
  176. io.Writer
  177. }
  178. // WriteMultiBuffer implements Writer.
  179. func (w *SequentialWriter) WriteMultiBuffer(mb MultiBuffer) error {
  180. mb, err := WriteMultiBuffer(w.Writer, mb)
  181. ReleaseMulti(mb)
  182. return err
  183. }
  184. type noOpWriter byte
  185. func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
  186. ReleaseMulti(b)
  187. return nil
  188. }
  189. func (noOpWriter) Write(b []byte) (int, error) {
  190. return len(b), nil
  191. }
  192. func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
  193. b := New()
  194. defer b.Release()
  195. totalBytes := int64(0)
  196. for {
  197. b.Clear()
  198. _, err := b.ReadFrom(reader)
  199. totalBytes += int64(b.Len())
  200. if err != nil {
  201. if errors.Cause(err) == io.EOF {
  202. return totalBytes, nil
  203. }
  204. return totalBytes, err
  205. }
  206. }
  207. }
  208. var (
  209. // Discard is a Writer that swallows all contents written in.
  210. Discard Writer = noOpWriter(0)
  211. // DiscardBytes is an io.Writer that swallows all contents written in.
  212. DiscardBytes io.Writer = noOpWriter(0)
  213. )