xudp.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package xudp
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "strings"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/platform"
  12. "github.com/xtls/xray-core/common/protocol"
  13. "github.com/xtls/xray-core/common/session"
  14. "lukechampine.com/blake3"
  15. )
  16. var AddrParser = protocol.NewAddressParser(
  17. protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4),
  18. protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain),
  19. protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6),
  20. protocol.PortThenAddress(),
  21. )
  22. var (
  23. Show bool
  24. BaseKey []byte
  25. )
  26. func init() {
  27. if strings.ToLower(platform.NewEnvFlag(platform.XUDPLog).GetValue(func() string { return "" })) == "true" {
  28. Show = true
  29. }
  30. if raw := platform.NewEnvFlag(platform.XUDPBaseKey).GetValue(func() string { return "" }); raw != "" {
  31. if BaseKey, _ = base64.RawURLEncoding.DecodeString(raw); len(BaseKey) == 32 {
  32. return
  33. }
  34. panic(platform.XUDPBaseKey + ": invalid value: " + raw)
  35. }
  36. rand.Read(BaseKey)
  37. }
  38. func GetGlobalID(ctx context.Context) (globalID [8]byte) {
  39. if cone := ctx.Value("cone"); cone == nil || !cone.(bool) { // cone is nil only in some unit tests
  40. return
  41. }
  42. if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
  43. (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
  44. h := blake3.New(8, BaseKey)
  45. h.Write([]byte(inbound.Source.String()))
  46. copy(globalID[:], h.Sum(nil))
  47. if Show {
  48. newError(fmt.Sprintf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID)).WriteToLog(session.ExportIDToError(ctx))
  49. }
  50. }
  51. return
  52. }
  53. func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter {
  54. return &PacketWriter{
  55. Writer: writer,
  56. Dest: dest,
  57. GlobalID: globalID,
  58. }
  59. }
  60. type PacketWriter struct {
  61. Writer buf.Writer
  62. Dest net.Destination
  63. GlobalID [8]byte
  64. }
  65. func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  66. defer buf.ReleaseMulti(mb)
  67. mb2Write := make(buf.MultiBuffer, 0, len(mb))
  68. for _, b := range mb {
  69. length := b.Len()
  70. if length == 0 || length+666 > buf.Size {
  71. continue
  72. }
  73. eb := buf.New()
  74. eb.Write([]byte{0, 0, 0, 0}) // Meta data length; Mux Session ID
  75. if w.Dest.Network == net.Network_UDP {
  76. eb.WriteByte(1) // New
  77. eb.WriteByte(1) // Opt
  78. eb.WriteByte(2) // UDP
  79. AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port)
  80. if b.UDP != nil { // make sure it's user's proxy request
  81. eb.Write(w.GlobalID[:]) // no need to check whether it's empty
  82. }
  83. w.Dest.Network = net.Network_Unknown
  84. } else {
  85. eb.WriteByte(2) // Keep
  86. eb.WriteByte(1) // Opt
  87. if b.UDP != nil {
  88. eb.WriteByte(2) // UDP
  89. AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port)
  90. }
  91. }
  92. l := eb.Len() - 2
  93. eb.SetByte(0, byte(l>>8))
  94. eb.SetByte(1, byte(l))
  95. eb.WriteByte(byte(length >> 8))
  96. eb.WriteByte(byte(length))
  97. eb.Write(b.Bytes())
  98. mb2Write = append(mb2Write, eb)
  99. }
  100. if mb2Write.IsEmpty() {
  101. return nil
  102. }
  103. return w.Writer.WriteMultiBuffer(mb2Write)
  104. }
  105. func NewPacketReader(reader io.Reader) *PacketReader {
  106. return &PacketReader{
  107. Reader: reader,
  108. cache: make([]byte, 2),
  109. }
  110. }
  111. type PacketReader struct {
  112. Reader io.Reader
  113. cache []byte
  114. }
  115. func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  116. for {
  117. if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
  118. return nil, err
  119. }
  120. l := int32(r.cache[0])<<8 | int32(r.cache[1])
  121. if l < 4 {
  122. return nil, io.EOF
  123. }
  124. b := buf.New()
  125. if _, err := b.ReadFullFrom(r.Reader, l); err != nil {
  126. b.Release()
  127. return nil, err
  128. }
  129. discard := false
  130. switch b.Byte(2) {
  131. case 2:
  132. if l > 4 && b.Byte(4) == 2 { // MUST check the flag first
  133. b.Advance(5)
  134. // b.Clear() will be called automatically if all data had been read.
  135. addr, port, err := AddrParser.ReadAddressPort(nil, b)
  136. if err != nil {
  137. b.Release()
  138. return nil, err
  139. }
  140. b.UDP = &net.Destination{
  141. Network: net.Network_UDP,
  142. Address: addr,
  143. Port: port,
  144. }
  145. }
  146. case 4:
  147. discard = true
  148. default:
  149. b.Release()
  150. return nil, io.EOF
  151. }
  152. b.Clear() // in case there is padding (empty bytes) attached
  153. if b.Byte(3) == 1 {
  154. if _, err := io.ReadFull(r.Reader, r.cache); err != nil {
  155. b.Release()
  156. return nil, err
  157. }
  158. length := int32(r.cache[0])<<8 | int32(r.cache[1])
  159. if length > 0 {
  160. if _, err := b.ReadFullFrom(r.Reader, length); err != nil {
  161. b.Release()
  162. return nil, err
  163. }
  164. if !discard {
  165. return buf.MultiBuffer{b}, nil
  166. }
  167. }
  168. }
  169. b.Release()
  170. }
  171. }