xudp.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package xudp
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "os"
  9. "strings"
  10. "github.com/xtls/xray-core/common/buf"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/session"
  14. "lukechampine.com/blake3"
  15. )
  16. var AddrParser = protocol.NewAddressParser(
  17. protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
  18. protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
  19. protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
  20. protocol.PortThenAddress(),
  21. )
  22. var (
  23. Show bool
  24. BaseKey []byte
  25. )
  26. const (
  27. EnvShow = "XRAY_XUDP_SHOW"
  28. EnvBaseKey = "XRAY_XUDP_BASEKEY"
  29. )
  30. func init() {
  31. if strings.ToLower(os.Getenv(EnvShow)) == "true" {
  32. Show = true
  33. }
  34. if raw, found := os.LookupEnv(EnvBaseKey); found {
  35. if BaseKey, _ = base64.RawURLEncoding.DecodeString(raw); len(BaseKey) == 32 {
  36. return
  37. }
  38. panic(EnvBaseKey + ": invalid value: " + raw)
  39. }
  40. rand.Read(BaseKey)
  41. }
  42. func GetGlobalID(ctx context.Context) (globalID [8]byte) {
  43. if cone := ctx.Value("cone"); cone == nil || !cone.(bool) { // cone is nil only in some unit tests
  44. return
  45. }
  46. if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
  47. (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
  48. h := blake3.New(8, BaseKey)
  49. h.Write([]byte(inbound.Source.String()))
  50. copy(globalID[:], h.Sum(nil))
  51. if Show {
  52. fmt.Printf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)
  53. }
  54. }
  55. return
  56. }
  57. func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter {
  58. return &PacketWriter{
  59. Writer: writer,
  60. Dest: dest,
  61. GlobalID: globalID,
  62. }
  63. }
  64. type PacketWriter struct {
  65. Writer buf.Writer
  66. Dest net.Destination
  67. GlobalID [8]byte
  68. }
  69. func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  70. defer buf.ReleaseMulti(mb)
  71. mb2Write := make(buf.MultiBuffer, 0, len(mb))
  72. for _, b := range mb {
  73. length := b.Len()
  74. if length == 0 || length+666 > buf.Size {
  75. continue
  76. }
  77. eb := buf.New()
  78. eb.Write([]byte{0, 0, 0, 0}) // Meta data length; Mux Session ID
  79. if w.Dest.Network == net.Network_UDP {
  80. eb.WriteByte(1) // New
  81. eb.WriteByte(1) // Opt
  82. eb.WriteByte(2) // UDP
  83. AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
  84. if b.UDP != nil { // make sure it's user's proxy request
  85. eb.Write(w.GlobalID[:]) // no need to check whether it's empty
  86. }
  87. w.Dest.Network = net.Network_Unknown
  88. } else {
  89. eb.WriteByte(2) // Keep
  90. eb.WriteByte(1) // Opt
  91. if b.UDP != nil {
  92. eb.WriteByte(2) // UDP
  93. AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
  94. }
  95. }
  96. l := eb.Len() - 2
  97. eb.SetByte(0, byte(l>>8))
  98. eb.SetByte(1, byte(l))
  99. eb.WriteByte(byte(length >> 8))
  100. eb.WriteByte(byte(length))
  101. eb.Write(b.Bytes())
  102. mb2Write = append(mb2Write, eb)
  103. }
  104. if mb2Write.IsEmpty() {
  105. return nil
  106. }
  107. return w.Writer.WriteMultiBuffer(mb2Write)
  108. }
  109. func NewPacketReader(reader io.Reader) *PacketReader {
  110. return &PacketReader{
  111. Reader: reader,
  112. cache: make([]byte, 2),
  113. }
  114. }
  115. type PacketReader struct {
  116. Reader io.Reader
  117. cache []byte
  118. }
  119. func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  120. for {
  121. if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
  122. return nil, err
  123. }
  124. l := int32(r.cache[0])<<8 | int32(r.cache[1])
  125. if l < 4 {
  126. return nil, io.EOF
  127. }
  128. b := buf.New()
  129. if _, err := b.ReadFullFrom(r.Reader, l); err != nil {
  130. b.Release()
  131. return nil, err
  132. }
  133. discard := false
  134. switch b.Byte(2) {
  135. case 2:
  136. if l > 4 && b.Byte(4) == 2 { // MUST check the flag first
  137. b.Advance(5)
  138. // b.Clear() will be called automatically if all data had been read.
  139. addr, port, err := AddrParser.ReadAddressPort(nil, b)
  140. if err != nil {
  141. b.Release()
  142. return nil, err
  143. }
  144. b.UDP = &net.Destination{
  145. Network: net.Network_UDP,
  146. Address: addr,
  147. Port: port,
  148. }
  149. }
  150. case 4:
  151. discard = true
  152. default:
  153. b.Release()
  154. return nil, io.EOF
  155. }
  156. b.Clear() // in case there is padding (empty bytes) attached
  157. if b.Byte(3) == 1 {
  158. if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
  159. b.Release()
  160. return nil, err
  161. }
  162. length := int32(r.cache[0])<<8 | int32(r.cache[1])
  163. if length > 0 {
  164. if _, err := b.ReadFullFrom(r.Reader, length); err != nil {
  165. b.Release()
  166. return nil, err
  167. }
  168. if !discard {
  169. return buf.MultiBuffer{b}, nil
  170. }
  171. }
  172. }
  173. b.Release()
  174. }
  175. }