quic.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. package sniff
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto"
  6. "crypto/aes"
  7. "crypto/tls"
  8. "encoding/binary"
  9. "io"
  10. "os"
  11. "github.com/sagernet/sing-box/adapter"
  12. "github.com/sagernet/sing-box/common/ja3"
  13. "github.com/sagernet/sing-box/common/sniff/internal/qtls"
  14. C "github.com/sagernet/sing-box/constant"
  15. "github.com/sagernet/sing/common/buf"
  16. E "github.com/sagernet/sing/common/exceptions"
  17. "golang.org/x/crypto/hkdf"
  18. )
  // ErrClientHelloFragmented is returned by QUICClientHello when ja3.Compute
  // fails on the CRYPTO data reassembled so far. Before returning it,
  // QUICClientHello stores the collected fragments in metadata.SniffContext so
  // that a later call can combine them with the fragments of the next packet.
  19. var ErrClientHelloFragmented = E.New("need more packet for chromium QUIC connection")
  20. func QUICClientHello(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
  21. reader := bytes.NewReader(packet)
  22. typeByte, err := reader.ReadByte()
  23. if err != nil {
  24. return err
  25. }
  26. if typeByte&0x40 == 0 {
  27. return E.New("bad type byte")
  28. }
  29. var versionNumber uint32
  30. err = binary.Read(reader, binary.BigEndian, &versionNumber)
  31. if err != nil {
  32. return err
  33. }
  34. if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
  35. return E.New("bad version")
  36. }
  37. packetType := (typeByte & 0x30) >> 4
  38. if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 {
  39. return E.New("bad packet type")
  40. }
  41. destConnIDLen, err := reader.ReadByte()
  42. if err != nil {
  43. return err
  44. }
  45. if destConnIDLen == 0 || destConnIDLen > 20 {
  46. return E.New("bad destination connection id length")
  47. }
  48. destConnID := make([]byte, destConnIDLen)
  49. _, err = io.ReadFull(reader, destConnID)
  50. if err != nil {
  51. return err
  52. }
  53. srcConnIDLen, err := reader.ReadByte()
  54. if err != nil {
  55. return err
  56. }
  57. _, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen))
  58. if err != nil {
  59. return err
  60. }
  61. tokenLen, err := qtls.ReadUvarint(reader)
  62. if err != nil {
  63. return err
  64. }
  65. _, err = io.CopyN(io.Discard, reader, int64(tokenLen))
  66. if err != nil {
  67. return err
  68. }
  69. packetLen, err := qtls.ReadUvarint(reader)
  70. if err != nil {
  71. return err
  72. }
  73. hdrLen := int(reader.Size()) - reader.Len()
  74. if hdrLen+int(packetLen) > len(packet) {
  75. return os.ErrInvalid
  76. }
  77. _, err = io.CopyN(io.Discard, reader, 4)
  78. if err != nil {
  79. return err
  80. }
  81. pnBytes := make([]byte, aes.BlockSize)
  82. _, err = io.ReadFull(reader, pnBytes)
  83. if err != nil {
  84. return err
  85. }
  86. var salt []byte
  87. switch versionNumber {
  88. case qtls.Version1:
  89. salt = qtls.SaltV1
  90. case qtls.Version2:
  91. salt = qtls.SaltV2
  92. default:
  93. salt = qtls.SaltOld
  94. }
  95. var hkdfHeaderProtectionLabel string
  96. switch versionNumber {
  97. case qtls.Version2:
  98. hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2
  99. default:
  100. hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1
  101. }
  102. initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
  103. secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
  104. hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16)
  105. block, err := aes.NewCipher(hpKey)
  106. if err != nil {
  107. return err
  108. }
  109. mask := make([]byte, aes.BlockSize)
  110. block.Encrypt(mask, pnBytes)
  111. newPacket := make([]byte, len(packet))
  112. copy(newPacket, packet)
  113. newPacket[0] ^= mask[0] & 0xf
  114. for i := range newPacket[hdrLen : hdrLen+4] {
  115. newPacket[hdrLen+i] ^= mask[i+1]
  116. }
  117. packetNumberLength := newPacket[0]&0x3 + 1
  118. if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen {
  119. return os.ErrInvalid
  120. }
  121. var packetNumber uint32
  122. switch packetNumberLength {
  123. case 1:
  124. packetNumber = uint32(newPacket[hdrLen])
  125. case 2:
  126. packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:]))
  127. case 3:
  128. packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16
  129. case 4:
  130. packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:])
  131. default:
  132. return E.New("bad packet number length")
  133. }
  134. extHdrLen := hdrLen + int(packetNumberLength)
  135. copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
  136. data := newPacket[extHdrLen : int(packetLen)+hdrLen]
  137. var keyLabel string
  138. var ivLabel string
  139. switch versionNumber {
  140. case qtls.Version2:
  141. keyLabel = qtls.HKDFLabelKeyV2
  142. ivLabel = qtls.HKDFLabelIVV2
  143. default:
  144. keyLabel = qtls.HKDFLabelKeyV1
  145. ivLabel = qtls.HKDFLabelIVV1
  146. }
  147. key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
  148. iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
  149. cipher := qtls.AEADAESGCMTLS13(key, iv)
  150. nonce := make([]byte, int32(cipher.NonceSize()))
  151. binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
  152. decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen])
  153. if err != nil {
  154. return err
  155. }
  156. var frameType byte
  157. var fragments []qCryptoFragment
  158. decryptedReader := bytes.NewReader(decrypted)
  159. const (
  160. frameTypePadding = 0x00
  161. frameTypePing = 0x01
  162. frameTypeAck = 0x02
  163. frameTypeAck2 = 0x03
  164. frameTypeCrypto = 0x06
  165. frameTypeConnectionClose = 0x1c
  166. )
  167. var frameTypeList []uint8
  168. for {
  169. frameType, err = decryptedReader.ReadByte()
  170. if err == io.EOF {
  171. break
  172. }
  173. frameTypeList = append(frameTypeList, frameType)
  174. switch frameType {
  175. case frameTypePadding:
  176. continue
  177. case frameTypePing:
  178. continue
  179. case frameTypeAck, frameTypeAck2:
  180. _, err = qtls.ReadUvarint(decryptedReader) // Largest Acknowledged
  181. if err != nil {
  182. return err
  183. }
  184. _, err = qtls.ReadUvarint(decryptedReader) // ACK Delay
  185. if err != nil {
  186. return err
  187. }
  188. ackRangeCount, err := qtls.ReadUvarint(decryptedReader) // ACK Range Count
  189. if err != nil {
  190. return err
  191. }
  192. _, err = qtls.ReadUvarint(decryptedReader) // First ACK Range
  193. if err != nil {
  194. return err
  195. }
  196. for i := 0; i < int(ackRangeCount); i++ {
  197. _, err = qtls.ReadUvarint(decryptedReader) // Gap
  198. if err != nil {
  199. return err
  200. }
  201. _, err = qtls.ReadUvarint(decryptedReader) // ACK Range Length
  202. if err != nil {
  203. return err
  204. }
  205. }
  206. if frameType == 0x03 {
  207. _, err = qtls.ReadUvarint(decryptedReader) // ECT0 Count
  208. if err != nil {
  209. return err
  210. }
  211. _, err = qtls.ReadUvarint(decryptedReader) // ECT1 Count
  212. if err != nil {
  213. return err
  214. }
  215. _, err = qtls.ReadUvarint(decryptedReader) // ECN-CE Count
  216. if err != nil {
  217. return err
  218. }
  219. }
  220. case frameTypeCrypto:
  221. var offset uint64
  222. offset, err = qtls.ReadUvarint(decryptedReader)
  223. if err != nil {
  224. return err
  225. }
  226. var length uint64
  227. length, err = qtls.ReadUvarint(decryptedReader)
  228. if err != nil {
  229. return err
  230. }
  231. index := len(decrypted) - decryptedReader.Len()
  232. fragments = append(fragments, qCryptoFragment{offset, length, decrypted[index : index+int(length)]})
  233. _, err = decryptedReader.Seek(int64(length), io.SeekCurrent)
  234. if err != nil {
  235. return err
  236. }
  237. case frameTypeConnectionClose:
  238. _, err = qtls.ReadUvarint(decryptedReader) // Error Code
  239. if err != nil {
  240. return err
  241. }
  242. _, err = qtls.ReadUvarint(decryptedReader) // Frame Type
  243. if err != nil {
  244. return err
  245. }
  246. var length uint64
  247. length, err = qtls.ReadUvarint(decryptedReader) // Reason Phrase Length
  248. if err != nil {
  249. return err
  250. }
  251. _, err = decryptedReader.Seek(int64(length), io.SeekCurrent) // Reason Phrase
  252. if err != nil {
  253. return err
  254. }
  255. default:
  256. return os.ErrInvalid
  257. }
  258. }
  259. if metadata.SniffContext != nil {
  260. fragments = append(fragments, metadata.SniffContext.([]qCryptoFragment)...)
  261. metadata.SniffContext = nil
  262. }
  263. var frameLen uint64
  264. for _, fragment := range fragments {
  265. frameLen += fragment.length
  266. }
  267. buffer := buf.NewSize(5 + int(frameLen))
  268. defer buffer.Release()
  269. buffer.WriteByte(0x16)
  270. binary.Write(buffer, binary.BigEndian, uint16(0x0303))
  271. binary.Write(buffer, binary.BigEndian, uint16(frameLen))
  272. var index uint64
  273. var length int
  274. find:
  275. for {
  276. for _, fragment := range fragments {
  277. if fragment.offset == index {
  278. buffer.Write(fragment.payload)
  279. index = fragment.offset + fragment.length
  280. length++
  281. continue find
  282. }
  283. }
  284. break
  285. }
  286. metadata.Protocol = C.ProtocolQUIC
  287. fingerprint, err := ja3.Compute(buffer.Bytes())
  288. if err != nil {
  289. metadata.Protocol = C.ProtocolQUIC
  290. metadata.Client = C.ClientChromium
  291. metadata.SniffContext = fragments
  292. return ErrClientHelloFragmented
  293. }
  294. metadata.Domain = fingerprint.ServerName
  295. for metadata.Client == "" {
  296. if len(frameTypeList) == 1 {
  297. metadata.Client = C.ClientFirefox
  298. break
  299. }
  300. if frameTypeList[0] == frameTypeCrypto && isZero(frameTypeList[1:]) {
  301. if len(fingerprint.Versions) == 2 && fingerprint.Versions[0]&ja3.GreaseBitmask == 0x0A0A &&
  302. len(fingerprint.EllipticCurves) == 5 && fingerprint.EllipticCurves[0]&ja3.GreaseBitmask == 0x0A0A {
  303. metadata.Client = C.ClientSafari
  304. break
  305. }
  306. if len(fingerprint.CipherSuites) == 1 && fingerprint.CipherSuites[0] == tls.TLS_AES_256_GCM_SHA384 &&
  307. len(fingerprint.EllipticCurves) == 1 && fingerprint.EllipticCurves[0] == uint16(tls.X25519) &&
  308. len(fingerprint.SignatureAlgorithms) == 1 && fingerprint.SignatureAlgorithms[0] == uint16(tls.ECDSAWithP256AndSHA256) {
  309. metadata.Client = C.ClientSafari
  310. break
  311. }
  312. }
  313. if frameTypeList[len(frameTypeList)-1] == frameTypeCrypto && isZero(frameTypeList[:len(frameTypeList)-1]) {
  314. metadata.Client = C.ClientQUICGo
  315. break
  316. }
  317. if count(frameTypeList, frameTypeCrypto) > 1 || count(frameTypeList, frameTypePing) > 0 {
  318. if maybeUQUIC(fingerprint) {
  319. metadata.Client = C.ClientQUICGo
  320. } else {
  321. metadata.Client = C.ClientChromium
  322. }
  323. break
  324. }
  325. metadata.Client = C.ClientUnknown
  326. //nolint:staticcheck
  327. break
  328. }
  329. return nil
  330. }
  331. func isZero(slices []uint8) bool {
  332. for _, slice := range slices {
  333. if slice != 0 {
  334. return false
  335. }
  336. }
  337. return true
  338. }
  339. func count(slices []uint8, value uint8) int {
  340. var times int
  341. for _, slice := range slices {
  342. if slice == value {
  343. times++
  344. }
  345. }
  346. return times
  347. }
  348. type qCryptoFragment struct {
  349. offset uint64
  350. length uint64
  351. payload []byte
  352. }