multiconn.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package encoding
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "google.golang.org/grpc/peer"
  7. "github.com/xtls/xray-core/common/buf"
  8. "github.com/xtls/xray-core/common/net/cnc"
  9. "github.com/xtls/xray-core/common/signal/done"
  10. )
  11. type MultiHunkConn interface {
  12. Context() context.Context
  13. Send(*MultiHunk) error
  14. Recv() (*MultiHunk, error)
  15. SendMsg(m interface{}) error
  16. RecvMsg(m interface{}) error
  17. }
  18. type MultiHunkReaderWriter struct {
  19. hc MultiHunkConn
  20. cancel context.CancelFunc
  21. done *done.Instance
  22. buf [][]byte
  23. }
  24. func NewMultiHunkReadWriter(hc MultiHunkConn, cancel context.CancelFunc) *MultiHunkReaderWriter {
  25. return &MultiHunkReaderWriter{hc, cancel, done.New(), nil}
  26. }
  27. func NewMultiHunkConn(hc MultiHunkConn, cancel context.CancelFunc) net.Conn {
  28. var rAddr net.Addr
  29. pr, ok := peer.FromContext(hc.Context())
  30. if ok {
  31. rAddr = pr.Addr
  32. } else {
  33. rAddr = &net.TCPAddr{
  34. IP: []byte{0, 0, 0, 0},
  35. Port: 0,
  36. }
  37. }
  38. wrc := NewMultiHunkReadWriter(hc, cancel)
  39. return cnc.NewConnection(
  40. cnc.ConnectionInputMulti(wrc),
  41. cnc.ConnectionOutputMulti(wrc),
  42. cnc.ConnectionOnClose(wrc),
  43. cnc.ConnectionRemoteAddr(rAddr),
  44. )
  45. }
  46. func (h *MultiHunkReaderWriter) forceFetch() error {
  47. hunk, err := h.hc.Recv()
  48. if err != nil {
  49. if err == io.EOF {
  50. return err
  51. }
  52. return newError("failed to fetch hunk from gRPC tunnel").Base(err)
  53. }
  54. h.buf = hunk.Data
  55. return nil
  56. }
  57. func (h *MultiHunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
  58. if h.done.Done() {
  59. return nil, io.EOF
  60. }
  61. if err := h.forceFetch(); err != nil {
  62. return nil, err
  63. }
  64. var mb = make(buf.MultiBuffer, 0, len(h.buf))
  65. for _, b := range h.buf {
  66. if len(b) == 0 {
  67. continue
  68. }
  69. if cap(b) >= buf.Size {
  70. mb = append(mb, buf.NewExisted(b))
  71. } else {
  72. nb := buf.New()
  73. nb.Extend(int32(len(b)))
  74. copy(nb.Bytes(), b)
  75. mb = append(mb, nb)
  76. }
  77. }
  78. return mb, nil
  79. }
  80. func (h *MultiHunkReaderWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  81. defer buf.ReleaseMulti(mb)
  82. if h.done.Done() {
  83. return io.ErrClosedPipe
  84. }
  85. hunks := make([][]byte, 0, len(mb))
  86. for _, b := range mb {
  87. if b.Len() > 0 {
  88. hunks = append(hunks, b.Bytes())
  89. }
  90. }
  91. err := h.hc.Send(&MultiHunk{Data: hunks})
  92. if err != nil {
  93. return err
  94. }
  95. return nil
  96. }
  97. func (h *MultiHunkReaderWriter) Close() error {
  98. if h.cancel != nil {
  99. h.cancel()
  100. }
  101. if sc, match := h.hc.(StreamCloser); match {
  102. return sc.CloseSend()
  103. }
  104. return h.done.Close()
  105. }