upload_queue.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package splithttp
// upload_queue is a specialized priority queue plus channel that reorders
// generic packets by their sequence number.
  4. import (
  5. "container/heap"
  6. "io"
  7. "runtime"
  8. "sync"
  9. "github.com/xtls/xray-core/common/errors"
  10. )
  11. type Packet struct {
  12. Reader io.ReadCloser
  13. Payload []byte
  14. Seq uint64
  15. }
  16. type uploadQueue struct {
  17. reader io.ReadCloser
  18. nomore bool
  19. pushedPackets chan Packet
  20. writeCloseMutex sync.Mutex
  21. heap uploadHeap
  22. nextSeq uint64
  23. closed bool
  24. maxPackets int
  25. }
  26. func NewUploadQueue(maxPackets int) *uploadQueue {
  27. return &uploadQueue{
  28. pushedPackets: make(chan Packet, maxPackets),
  29. heap: uploadHeap{},
  30. nextSeq: 0,
  31. closed: false,
  32. maxPackets: maxPackets,
  33. }
  34. }
  35. func (h *uploadQueue) Push(p Packet) error {
  36. h.writeCloseMutex.Lock()
  37. defer h.writeCloseMutex.Unlock()
  38. if h.closed {
  39. return errors.New("packet queue closed")
  40. }
  41. if h.nomore {
  42. return errors.New("h.reader already exists")
  43. }
  44. if p.Reader != nil {
  45. h.nomore = true
  46. }
  47. h.pushedPackets <- p
  48. return nil
  49. }
  50. func (h *uploadQueue) Close() error {
  51. h.writeCloseMutex.Lock()
  52. defer h.writeCloseMutex.Unlock()
  53. if !h.closed {
  54. h.closed = true
  55. runtime.Gosched() // hope Read() gets the packet
  56. f:
  57. for {
  58. select {
  59. case p := <-h.pushedPackets:
  60. if p.Reader != nil {
  61. h.reader = p.Reader
  62. }
  63. default:
  64. break f
  65. }
  66. }
  67. close(h.pushedPackets)
  68. }
  69. if h.reader != nil {
  70. return h.reader.Close()
  71. }
  72. return nil
  73. }
  74. func (h *uploadQueue) Read(b []byte) (int, error) {
  75. if h.reader != nil {
  76. return h.reader.Read(b)
  77. }
  78. if h.closed {
  79. return 0, io.EOF
  80. }
  81. if len(h.heap) == 0 {
  82. packet, more := <-h.pushedPackets
  83. if !more {
  84. return 0, io.EOF
  85. }
  86. if packet.Reader != nil {
  87. h.reader = packet.Reader
  88. return h.reader.Read(b)
  89. }
  90. heap.Push(&h.heap, packet)
  91. }
  92. for len(h.heap) > 0 {
  93. packet := heap.Pop(&h.heap).(Packet)
  94. n := 0
  95. if packet.Seq == h.nextSeq {
  96. copy(b, packet.Payload)
  97. n = min(len(b), len(packet.Payload))
  98. if n < len(packet.Payload) {
  99. // partial read
  100. packet.Payload = packet.Payload[n:]
  101. heap.Push(&h.heap, packet)
  102. } else {
  103. h.nextSeq = packet.Seq + 1
  104. }
  105. return n, nil
  106. }
  107. // misordered packet
  108. if packet.Seq > h.nextSeq {
  109. if len(h.heap) > h.maxPackets {
  110. // the "reassembly buffer" is too large, and we want to
  111. // constrain memory usage somehow. let's tear down the
  112. // connection, and hope the application retries.
  113. return 0, errors.New("packet queue is too large")
  114. }
  115. heap.Push(&h.heap, packet)
  116. packet2, more := <-h.pushedPackets
  117. if !more {
  118. return 0, io.EOF
  119. }
  120. heap.Push(&h.heap, packet2)
  121. }
  122. }
  123. return 0, nil
  124. }
  125. // heap code directly taken from https://pkg.go.dev/container/heap
  126. type uploadHeap []Packet
  127. func (h uploadHeap) Len() int { return len(h) }
  128. func (h uploadHeap) Less(i, j int) bool { return h[i].Seq < h[j].Seq }
  129. func (h uploadHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  130. func (h *uploadHeap) Push(x any) {
  131. // Push and Pop use pointer receivers because they modify the slice's length,
  132. // not just its contents.
  133. *h = append(*h, x.(Packet))
  134. }
  135. func (h *uploadHeap) Pop() any {
  136. old := *h
  137. n := len(old)
  138. x := old[n-1]
  139. *h = old[0 : n-1]
  140. return x
  141. }