upload_queue.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package splithttp
  // upload_queue implements a specialized priority queue plus channel that
  // reorders generic packets by their sequence number.
  4. import (
  5. "container/heap"
  6. "io"
  7. "sync"
  8. "github.com/xtls/xray-core/common/errors"
  9. )
  // Packet is a single chunk of upload data together with the sequence
  // number that uploadQueue uses to put chunks back in order.
  type Packet struct {
  	Payload []byte // data carried by this chunk
  	Seq     uint64 // ordering key; uploadQueue delivers payloads in ascending Seq, starting from 0
  }
  14. type uploadQueue struct {
  15. pushedPackets chan Packet
  16. writeCloseMutex sync.Mutex
  17. heap uploadHeap
  18. nextSeq uint64
  19. closed bool
  20. maxPackets int
  21. }
  22. func NewUploadQueue(maxPackets int) *uploadQueue {
  23. return &uploadQueue{
  24. pushedPackets: make(chan Packet, maxPackets),
  25. heap: uploadHeap{},
  26. nextSeq: 0,
  27. closed: false,
  28. maxPackets: maxPackets,
  29. }
  30. }
  31. func (h *uploadQueue) Push(p Packet) error {
  32. h.writeCloseMutex.Lock()
  33. defer h.writeCloseMutex.Unlock()
  34. if h.closed {
  35. return errors.New("splithttp packet queue closed")
  36. }
  37. h.pushedPackets <- p
  38. return nil
  39. }
  40. func (h *uploadQueue) Close() error {
  41. h.writeCloseMutex.Lock()
  42. defer h.writeCloseMutex.Unlock()
  43. h.closed = true
  44. close(h.pushedPackets)
  45. return nil
  46. }
  47. func (h *uploadQueue) Read(b []byte) (int, error) {
  48. if h.closed {
  49. return 0, io.EOF
  50. }
  51. if len(h.heap) == 0 {
  52. packet, more := <-h.pushedPackets
  53. if !more {
  54. return 0, io.EOF
  55. }
  56. heap.Push(&h.heap, packet)
  57. }
  58. for len(h.heap) > 0 {
  59. packet := heap.Pop(&h.heap).(Packet)
  60. n := 0
  61. if packet.Seq == h.nextSeq {
  62. copy(b, packet.Payload)
  63. n = min(len(b), len(packet.Payload))
  64. if n < len(packet.Payload) {
  65. // partial read
  66. packet.Payload = packet.Payload[n:]
  67. heap.Push(&h.heap, packet)
  68. } else {
  69. h.nextSeq = packet.Seq + 1
  70. }
  71. return n, nil
  72. }
  73. // misordered packet
  74. if packet.Seq > h.nextSeq {
  75. if len(h.heap) > h.maxPackets {
  76. // the "reassembly buffer" is too large, and we want to
  77. // constrain memory usage somehow. let's tear down the
  78. // connection, and hope the application retries.
  79. return 0, errors.New("packet queue is too large")
  80. }
  81. heap.Push(&h.heap, packet)
  82. packet2, more := <-h.pushedPackets
  83. if !more {
  84. return 0, io.EOF
  85. }
  86. heap.Push(&h.heap, packet2)
  87. }
  88. }
  89. return 0, nil
  90. }
  91. // heap code directly taken from https://pkg.go.dev/container/heap
  92. type uploadHeap []Packet
  93. func (h uploadHeap) Len() int { return len(h) }
  94. func (h uploadHeap) Less(i, j int) bool { return h[i].Seq < h[j].Seq }
  95. func (h uploadHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  96. func (h *uploadHeap) Push(x any) {
  97. // Push and Pop use pointer receivers because they modify the slice's length,
  98. // not just its contents.
  99. *h = append(*h, x.(Packet))
  100. }
  101. func (h *uploadHeap) Pop() any {
  102. old := *h
  103. n := len(old)
  104. x := old[n-1]
  105. *h = old[0 : n-1]
  106. return x
  107. }