1
0

hunkconn.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package encoding
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "github.com/xtls/xray-core/common/buf"
  7. "github.com/xtls/xray-core/common/errors"
  8. xnet "github.com/xtls/xray-core/common/net"
  9. "github.com/xtls/xray-core/common/net/cnc"
  10. "github.com/xtls/xray-core/common/signal/done"
  11. "google.golang.org/grpc/metadata"
  12. "google.golang.org/grpc/peer"
  13. )
  14. type HunkConn interface {
  15. Context() context.Context
  16. Send(*Hunk) error
  17. Recv() (*Hunk, error)
  18. SendMsg(m interface{}) error
  19. RecvMsg(m interface{}) error
  20. }
  21. type StreamCloser interface {
  22. CloseSend() error
  23. }
  24. type HunkReaderWriter struct {
  25. hc HunkConn
  26. cancel context.CancelFunc
  27. done *done.Instance
  28. buf []byte
  29. index int
  30. }
  31. func NewHunkReadWriter(hc HunkConn, cancel context.CancelFunc) *HunkReaderWriter {
  32. return &HunkReaderWriter{hc, cancel, done.New(), nil, 0}
  33. }
  34. func NewHunkConn(hc HunkConn, cancel context.CancelFunc) net.Conn {
  35. var rAddr net.Addr
  36. pr, ok := peer.FromContext(hc.Context())
  37. if ok {
  38. rAddr = pr.Addr
  39. } else {
  40. rAddr = &net.TCPAddr{
  41. IP: []byte{0, 0, 0, 0},
  42. Port: 0,
  43. }
  44. }
  45. md, ok := metadata.FromIncomingContext(hc.Context())
  46. if ok {
  47. header := md.Get("x-real-ip")
  48. if len(header) > 0 {
  49. realip := xnet.ParseAddress(header[0])
  50. if realip.Family().IsIP() {
  51. rAddr = &net.TCPAddr{
  52. IP: realip.IP(),
  53. Port: 0,
  54. }
  55. }
  56. }
  57. }
  58. wrc := NewHunkReadWriter(hc, cancel)
  59. return cnc.NewConnection(
  60. cnc.ConnectionInput(wrc),
  61. cnc.ConnectionOutput(wrc),
  62. cnc.ConnectionOnClose(wrc),
  63. cnc.ConnectionRemoteAddr(rAddr),
  64. )
  65. }
  66. func (h *HunkReaderWriter) forceFetch() error {
  67. hunk, err := h.hc.Recv()
  68. if err != nil {
  69. if err == io.EOF {
  70. return err
  71. }
  72. return errors.New("failed to fetch hunk from gRPC tunnel").Base(err)
  73. }
  74. h.buf = hunk.Data
  75. h.index = 0
  76. return nil
  77. }
  78. func (h *HunkReaderWriter) Read(buf []byte) (int, error) {
  79. if h.done.Done() {
  80. return 0, io.EOF
  81. }
  82. if h.index >= len(h.buf) {
  83. if err := h.forceFetch(); err != nil {
  84. return 0, err
  85. }
  86. }
  87. n := copy(buf, h.buf[h.index:])
  88. h.index += n
  89. return n, nil
  90. }
  91. func (h *HunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
  92. if h.done.Done() {
  93. return nil, io.EOF
  94. }
  95. if h.index >= len(h.buf) {
  96. if err := h.forceFetch(); err != nil {
  97. return nil, err
  98. }
  99. }
  100. if cap(h.buf) >= buf.Size {
  101. b := h.buf
  102. h.index = len(h.buf)
  103. return buf.MultiBuffer{buf.NewExisted(b)}, nil
  104. }
  105. b := buf.New()
  106. _, err := b.ReadFrom(h)
  107. if err != nil {
  108. return nil, err
  109. }
  110. return buf.MultiBuffer{b}, nil
  111. }
  112. func (h *HunkReaderWriter) Write(buf []byte) (int, error) {
  113. if h.done.Done() {
  114. return 0, io.ErrClosedPipe
  115. }
  116. err := h.hc.Send(&Hunk{Data: buf[:]})
  117. if err != nil {
  118. return 0, errors.New("failed to send data over gRPC tunnel").Base(err)
  119. }
  120. return len(buf), nil
  121. }
  122. func (h *HunkReaderWriter) Close() error {
  123. if h.cancel != nil {
  124. h.cancel()
  125. }
  126. if sc, match := h.hc.(StreamCloser); match {
  127. return sc.CloseSend()
  128. }
  129. return h.done.Close()
  130. }