multiconn.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package encoding
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "github.com/xtls/xray-core/common/buf"
  7. "github.com/xtls/xray-core/common/errors"
  8. xnet "github.com/xtls/xray-core/common/net"
  9. "github.com/xtls/xray-core/common/net/cnc"
  10. "github.com/xtls/xray-core/common/signal/done"
  11. "google.golang.org/grpc/metadata"
  12. "google.golang.org/grpc/peer"
  13. )
  14. type MultiHunkConn interface {
  15. Context() context.Context
  16. Send(*MultiHunk) error
  17. Recv() (*MultiHunk, error)
  18. SendMsg(m interface{}) error
  19. RecvMsg(m interface{}) error
  20. }
  21. type MultiHunkReaderWriter struct {
  22. hc MultiHunkConn
  23. cancel context.CancelFunc
  24. done *done.Instance
  25. buf [][]byte
  26. }
  27. func NewMultiHunkReadWriter(hc MultiHunkConn, cancel context.CancelFunc) *MultiHunkReaderWriter {
  28. return &MultiHunkReaderWriter{hc, cancel, done.New(), nil}
  29. }
  30. func NewMultiHunkConn(hc MultiHunkConn, cancel context.CancelFunc) net.Conn {
  31. var rAddr net.Addr
  32. pr, ok := peer.FromContext(hc.Context())
  33. if ok {
  34. rAddr = pr.Addr
  35. } else {
  36. rAddr = &net.TCPAddr{
  37. IP: []byte{0, 0, 0, 0},
  38. Port: 0,
  39. }
  40. }
  41. md, ok := metadata.FromIncomingContext(hc.Context())
  42. if ok {
  43. header := md.Get("x-real-ip")
  44. if len(header) > 0 {
  45. realip := xnet.ParseAddress(header[0])
  46. if realip.Family().IsIP() {
  47. rAddr = &net.TCPAddr{
  48. IP: realip.IP(),
  49. Port: 0,
  50. }
  51. }
  52. }
  53. }
  54. wrc := NewMultiHunkReadWriter(hc, cancel)
  55. return cnc.NewConnection(
  56. cnc.ConnectionInputMulti(wrc),
  57. cnc.ConnectionOutputMulti(wrc),
  58. cnc.ConnectionOnClose(wrc),
  59. cnc.ConnectionRemoteAddr(rAddr),
  60. )
  61. }
  62. func (h *MultiHunkReaderWriter) forceFetch() error {
  63. hunk, err := h.hc.Recv()
  64. if err != nil {
  65. if err == io.EOF {
  66. return err
  67. }
  68. return errors.New("failed to fetch hunk from gRPC tunnel").Base(err)
  69. }
  70. h.buf = hunk.Data
  71. return nil
  72. }
  73. func (h *MultiHunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
  74. if h.done.Done() {
  75. return nil, io.EOF
  76. }
  77. if err := h.forceFetch(); err != nil {
  78. return nil, err
  79. }
  80. mb := make(buf.MultiBuffer, 0, len(h.buf))
  81. for _, b := range h.buf {
  82. if len(b) == 0 {
  83. continue
  84. }
  85. if cap(b) >= buf.Size {
  86. mb = append(mb, buf.NewExisted(b))
  87. } else {
  88. nb := buf.New()
  89. nb.Extend(int32(len(b)))
  90. copy(nb.Bytes(), b)
  91. mb = append(mb, nb)
  92. }
  93. }
  94. return mb, nil
  95. }
  96. func (h *MultiHunkReaderWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  97. defer buf.ReleaseMulti(mb)
  98. if h.done.Done() {
  99. return io.ErrClosedPipe
  100. }
  101. hunks := make([][]byte, 0, len(mb))
  102. for _, b := range mb {
  103. if b.Len() > 0 {
  104. hunks = append(hunks, b.Bytes())
  105. }
  106. }
  107. err := h.hc.Send(&MultiHunk{Data: hunks})
  108. if err != nil {
  109. return err
  110. }
  111. return nil
  112. }
  113. func (h *MultiHunkReaderWriter) Close() error {
  114. if h.cancel != nil {
  115. h.cancel()
  116. }
  117. if sc, match := h.hc.(StreamCloser); match {
  118. return sc.CloseSend()
  119. }
  120. return h.done.Close()
  121. }