hunkconn.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package encoding
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "google.golang.org/grpc/peer"
  7. "github.com/xtls/xray-core/common/buf"
  8. "github.com/xtls/xray-core/common/net/cnc"
  9. "github.com/xtls/xray-core/common/signal/done"
  10. )
  11. type HunkConn interface {
  12. Context() context.Context
  13. Send(*Hunk) error
  14. Recv() (*Hunk, error)
  15. SendMsg(m interface{}) error
  16. RecvMsg(m interface{}) error
  17. }
  18. type StreamCloser interface {
  19. CloseSend() error
  20. }
  21. type HunkReaderWriter struct {
  22. hc HunkConn
  23. cancel context.CancelFunc
  24. done *done.Instance
  25. buf []byte
  26. index int
  27. }
  // NewHunkReadWriter wraps hc in a HunkReaderWriter that starts with no
  // buffered data and a fresh, open done signal. cancel may be nil; if
  // non-nil it is called when the HunkReaderWriter is closed.
  func NewHunkReadWriter(hc HunkConn, cancel context.CancelFunc) *HunkReaderWriter {
  	return &HunkReaderWriter{
  		hc:     hc,
  		cancel: cancel,
  		done:   done.New(),
  		// buf and index deliberately start at their zero values, so the
  		// first read receives a hunk from the stream.
  	}
  }
  // NewHunkConn adapts a hunk stream to a net.Conn built by cnc.NewConnection.
  // A single HunkReaderWriter serves as the connection's input, output and
  // close hook. The remote address comes from the gRPC peer attached to hc's
  // context. When there is no peer, or the peer carries no address, the
  // placeholder 0.0.0.0:0 is used so the connection is never given a nil
  // remote address.
  func NewHunkConn(hc HunkConn, cancel context.CancelFunc) net.Conn {
  	var rAddr net.Addr = &net.TCPAddr{
  		IP: net.IP{0, 0, 0, 0}, // Port left at its zero value, 0
  	}
  	// peer.Peer.Addr is an interface and may be nil; only override the
  	// placeholder with a usable address.
  	if pr, ok := peer.FromContext(hc.Context()); ok && pr.Addr != nil {
  		rAddr = pr.Addr
  	}

  	wrc := NewHunkReadWriter(hc, cancel)
  	return cnc.NewConnection(
  		cnc.ConnectionInput(wrc),
  		cnc.ConnectionOutput(wrc),
  		cnc.ConnectionOnClose(wrc),
  		cnc.ConnectionRemoteAddr(rAddr),
  	)
  }
  // forceFetch receives the next hunk unconditionally and makes its payload
  // the current buffer, resetting the read offset to the start. io.EOF is
  // passed through unwrapped so callers can report a clean end of stream;
  // any other receive error is wrapped with context.
  func (h *HunkReaderWriter) forceFetch() error {
  	hunk, err := h.hc.Recv()
  	switch {
  	case err == io.EOF:
  		return err
  	case err != nil:
  		return newError("failed to fetch hunk from gRPC tunnel").Base(err)
  	}
  	h.buf, h.index = hunk.Data, 0
  	return nil
  }
  // Read implements io.Reader over the stream of hunks.
  //
  // It returns io.EOF once the HunkReaderWriter is done. Otherwise it copies
  // as many unread bytes of the current hunk as fit into p. When the current
  // hunk is exhausted, it first receives new hunks until one has data. Any
  // error from receiving is returned with n == 0.
  func (h *HunkReaderWriter) Read(p []byte) (int, error) {
  	if h.done.Done() {
  		return 0, io.EOF
  	}
  	// Loop rather than fetch once: an empty hunk would otherwise produce a
  	// (0, nil) result, which the io.Reader contract discourages and which
  	// can make callers loop without progress.
  	for h.index >= len(h.buf) {
  		if err := h.forceFetch(); err != nil {
  			return 0, err
  		}
  	}
  	n := copy(p, h.buf[h.index:])
  	h.index += n
  	return n, nil
  }
  // ReadMultiBuffer returns the unread part of the current hunk as a
  // single-buffer MultiBuffer. If the current hunk is exhausted, it first
  // receives new hunks until one has data.
  //
  // If the unread slice has capacity of at least buf.Size, it is wrapped in
  // place with buf.NewExisted (no copy) and the whole hunk is marked
  // consumed. Otherwise its bytes are read into a buffer from buf.New.
  // It returns io.EOF once the HunkReaderWriter is done, and propagates
  // receive errors.
  func (h *HunkReaderWriter) ReadMultiBuffer() (buf.MultiBuffer, error) {
  	if h.done.Done() {
  		return nil, io.EOF
  	}
  	// Skip empty hunks, matching Read.
  	for h.index >= len(h.buf) {
  		if err := h.forceFetch(); err != nil {
  			return nil, err
  		}
  	}

  	// Hand out only the unread tail: after a partial Read, the bytes before
  	// h.index have already been delivered and must not be returned again.
  	remaining := h.buf[h.index:]
  	// NOTE(review): the capacity (not length) check presumably satisfies a
  	// precondition of buf.NewExisted — confirm in common/buf.
  	if cap(remaining) >= buf.Size {
  		h.index = len(h.buf)
  		return buf.MultiBuffer{buf.NewExisted(remaining)}, nil
  	}

  	b := buf.New()
  	if _, err := b.ReadFrom(h); err != nil {
  		// Give the freshly obtained buffer back instead of dropping it.
  		b.Release()
  		return nil, err
  	}
  	return buf.MultiBuffer{b}, nil
  }
  // Write sends all of p to the peer as a single Hunk and reports len(p) on
  // success. It fails with io.ErrClosedPipe once the HunkReaderWriter is
  // done, and wraps any error from Send.
  func (h *HunkReaderWriter) Write(p []byte) (int, error) {
  	if h.done.Done() {
  		return 0, io.ErrClosedPipe
  	}
  	if err := h.hc.Send(&Hunk{Data: p}); err != nil {
  		return 0, newError("failed to send data over gRPC tunnel").Base(err)
  	}
  	return len(p), nil
  }
  106. func (h *HunkReaderWriter) Close() error {
  107. if h.cancel != nil {
  108. h.cancel()
  109. }
  110. if sc, match := h.hc.(StreamCloser); match {
  111. return sc.CloseSend()
  112. }
  113. return h.done.Close()
  114. }