addons.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package encoding
  2. import (
  3. "context"
  4. "io"
  5. "net"
  6. "github.com/xtls/xray-core/common/buf"
  7. "github.com/xtls/xray-core/common/errors"
  8. "github.com/xtls/xray-core/common/protocol"
  9. "github.com/xtls/xray-core/common/session"
  10. "github.com/xtls/xray-core/proxy"
  11. "github.com/xtls/xray-core/proxy/vless"
  12. "google.golang.org/protobuf/proto"
  13. )
  14. func EncodeHeaderAddons(buffer *buf.Buffer, addons *Addons) error {
  15. switch addons.Flow {
  16. case vless.XRV:
  17. bytes, err := proto.Marshal(addons)
  18. if err != nil {
  19. return errors.New("failed to marshal addons protobuf value").Base(err)
  20. }
  21. if err := buffer.WriteByte(byte(len(bytes))); err != nil {
  22. return errors.New("failed to write addons protobuf length").Base(err)
  23. }
  24. if _, err := buffer.Write(bytes); err != nil {
  25. return errors.New("failed to write addons protobuf value").Base(err)
  26. }
  27. default:
  28. if err := buffer.WriteByte(0); err != nil {
  29. return errors.New("failed to write addons protobuf length").Base(err)
  30. }
  31. }
  32. return nil
  33. }
  34. func DecodeHeaderAddons(buffer *buf.Buffer, reader io.Reader) (*Addons, error) {
  35. addons := new(Addons)
  36. buffer.Clear()
  37. if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
  38. return nil, errors.New("failed to read addons protobuf length").Base(err)
  39. }
  40. if length := int32(buffer.Byte(0)); length != 0 {
  41. buffer.Clear()
  42. if _, err := buffer.ReadFullFrom(reader, length); err != nil {
  43. return nil, errors.New("failed to read addons protobuf value").Base(err)
  44. }
  45. if err := proto.Unmarshal(buffer.Bytes(), addons); err != nil {
  46. return nil, errors.New("failed to unmarshal addons protobuf value").Base(err)
  47. }
  48. // Verification.
  49. switch addons.Flow {
  50. default:
  51. }
  52. }
  53. return addons, nil
  54. }
  55. // EncodeBodyAddons returns a Writer that auto-encrypt content written by caller.
  56. func EncodeBodyAddons(writer buf.Writer, request *protocol.RequestHeader, requestAddons *Addons, state *proxy.TrafficState, isUplink bool, context context.Context, conn net.Conn, ob *session.Outbound) buf.Writer {
  57. if request.Command == protocol.RequestCommandUDP {
  58. return NewMultiLengthPacketWriter(writer)
  59. }
  60. if requestAddons.Flow == vless.XRV {
  61. return proxy.NewVisionWriter(writer, state, isUplink, context, conn, ob, request.User.Account.(*vless.MemoryAccount).Testseed)
  62. }
  63. return writer
  64. }
  65. // DecodeBodyAddons returns a Reader from which caller can fetch decrypted body.
  66. func DecodeBodyAddons(reader io.Reader, request *protocol.RequestHeader, addons *Addons) buf.Reader {
  67. switch addons.Flow {
  68. default:
  69. if request.Command == protocol.RequestCommandUDP {
  70. return NewLengthPacketReader(reader)
  71. }
  72. }
  73. return buf.NewReader(reader)
  74. }
  75. func NewMultiLengthPacketWriter(writer buf.Writer) *MultiLengthPacketWriter {
  76. return &MultiLengthPacketWriter{
  77. Writer: writer,
  78. }
  79. }
  80. type MultiLengthPacketWriter struct {
  81. buf.Writer
  82. }
  83. func (w *MultiLengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  84. defer buf.ReleaseMulti(mb)
  85. mb2Write := make(buf.MultiBuffer, 0, len(mb)+1)
  86. for _, b := range mb {
  87. length := b.Len()
  88. if length == 0 || length+2 > buf.Size {
  89. continue
  90. }
  91. eb := buf.New()
  92. if err := eb.WriteByte(byte(length >> 8)); err != nil {
  93. eb.Release()
  94. continue
  95. }
  96. if err := eb.WriteByte(byte(length)); err != nil {
  97. eb.Release()
  98. continue
  99. }
  100. if _, err := eb.Write(b.Bytes()); err != nil {
  101. eb.Release()
  102. continue
  103. }
  104. mb2Write = append(mb2Write, eb)
  105. }
  106. if mb2Write.IsEmpty() {
  107. return nil
  108. }
  109. return w.Writer.WriteMultiBuffer(mb2Write)
  110. }
  111. func NewLengthPacketWriter(writer io.Writer) *LengthPacketWriter {
  112. return &LengthPacketWriter{
  113. Writer: writer,
  114. cache: make([]byte, 0, 65536),
  115. }
  116. }
  117. type LengthPacketWriter struct {
  118. io.Writer
  119. cache []byte
  120. }
  121. func (w *LengthPacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  122. length := mb.Len() // none of mb is nil
  123. // fmt.Println("Write", length)
  124. if length == 0 {
  125. return nil
  126. }
  127. defer func() {
  128. w.cache = w.cache[:0]
  129. }()
  130. w.cache = append(w.cache, byte(length>>8), byte(length))
  131. for i, b := range mb {
  132. w.cache = append(w.cache, b.Bytes()...)
  133. b.Release()
  134. mb[i] = nil
  135. }
  136. if _, err := w.Write(w.cache); err != nil {
  137. return errors.New("failed to write a packet").Base(err)
  138. }
  139. return nil
  140. }
  141. func NewLengthPacketReader(reader io.Reader) *LengthPacketReader {
  142. return &LengthPacketReader{
  143. Reader: reader,
  144. cache: make([]byte, 2),
  145. }
  146. }
  147. type LengthPacketReader struct {
  148. io.Reader
  149. cache []byte
  150. }
  151. func (r *LengthPacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  152. if _, err := io.ReadFull(r.Reader, r.cache); err != nil { // maybe EOF
  153. return nil, errors.New("failed to read packet length").Base(err)
  154. }
  155. length := int32(r.cache[0])<<8 | int32(r.cache[1])
  156. // fmt.Println("Read", length)
  157. mb := make(buf.MultiBuffer, 0, length/buf.Size+1)
  158. for length > 0 {
  159. size := length
  160. if size > buf.Size {
  161. size = buf.Size
  162. }
  163. length -= size
  164. b := buf.New()
  165. if _, err := b.ReadFullFrom(r.Reader, size); err != nil {
  166. return nil, errors.New("failed to read packet payload").Base(err)
  167. }
  168. mb = append(mb, b)
  169. }
  170. return mb, nil
  171. }