quic.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package sniff
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto"
  6. "crypto/aes"
  7. "encoding/binary"
  8. "io"
  9. "os"
  10. "github.com/sagernet/sing-box/adapter"
  11. "github.com/sagernet/sing-box/common/sniff/internal/qtls"
  12. C "github.com/sagernet/sing-box/constant"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "golang.org/x/crypto/hkdf"
  15. )
  16. func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
  17. reader := bytes.NewReader(packet)
  18. typeByte, err := reader.ReadByte()
  19. if err != nil {
  20. return nil, err
  21. }
  22. if typeByte&0x40 == 0 {
  23. return nil, E.New("bad type byte")
  24. }
  25. var versionNumber uint32
  26. err = binary.Read(reader, binary.BigEndian, &versionNumber)
  27. if err != nil {
  28. return nil, err
  29. }
  30. if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
  31. return nil, E.New("bad version")
  32. }
  33. packetType := (typeByte & 0x30) >> 4
  34. if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 {
  35. return nil, E.New("bad packet type")
  36. }
  37. destConnIDLen, err := reader.ReadByte()
  38. if err != nil {
  39. return nil, err
  40. }
  41. if destConnIDLen == 0 || destConnIDLen > 20 {
  42. return nil, E.New("bad destination connection id length")
  43. }
  44. destConnID := make([]byte, destConnIDLen)
  45. _, err = io.ReadFull(reader, destConnID)
  46. if err != nil {
  47. return nil, err
  48. }
  49. srcConnIDLen, err := reader.ReadByte()
  50. if err != nil {
  51. return nil, err
  52. }
  53. _, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen))
  54. if err != nil {
  55. return nil, err
  56. }
  57. tokenLen, err := qtls.ReadUvarint(reader)
  58. if err != nil {
  59. return nil, err
  60. }
  61. _, err = io.CopyN(io.Discard, reader, int64(tokenLen))
  62. if err != nil {
  63. return nil, err
  64. }
  65. packetLen, err := qtls.ReadUvarint(reader)
  66. if err != nil {
  67. return nil, err
  68. }
  69. hdrLen := int(reader.Size()) - reader.Len()
  70. if hdrLen+int(packetLen) > len(packet) {
  71. return nil, os.ErrInvalid
  72. }
  73. _, err = io.CopyN(io.Discard, reader, 4)
  74. if err != nil {
  75. return nil, err
  76. }
  77. pnBytes := make([]byte, aes.BlockSize)
  78. _, err = io.ReadFull(reader, pnBytes)
  79. if err != nil {
  80. return nil, err
  81. }
  82. var salt []byte
  83. switch versionNumber {
  84. case qtls.Version1:
  85. salt = qtls.SaltV1
  86. case qtls.Version2:
  87. salt = qtls.SaltV2
  88. default:
  89. salt = qtls.SaltOld
  90. }
  91. var hkdfHeaderProtectionLabel string
  92. switch versionNumber {
  93. case qtls.Version2:
  94. hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2
  95. default:
  96. hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1
  97. }
  98. initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
  99. secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
  100. hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16)
  101. block, err := aes.NewCipher(hpKey)
  102. if err != nil {
  103. return nil, err
  104. }
  105. mask := make([]byte, aes.BlockSize)
  106. block.Encrypt(mask, pnBytes)
  107. newPacket := make([]byte, len(packet))
  108. copy(newPacket, packet)
  109. newPacket[0] ^= mask[0] & 0xf
  110. for i := range newPacket[hdrLen : hdrLen+4] {
  111. newPacket[hdrLen+i] ^= mask[i+1]
  112. }
  113. packetNumberLength := newPacket[0]&0x3 + 1
  114. if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen {
  115. return nil, os.ErrInvalid
  116. }
  117. var packetNumber uint32
  118. switch packetNumberLength {
  119. case 1:
  120. packetNumber = uint32(newPacket[hdrLen])
  121. case 2:
  122. packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:]))
  123. case 3:
  124. packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16
  125. case 4:
  126. packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:])
  127. default:
  128. return nil, E.New("bad packet number length")
  129. }
  130. extHdrLen := hdrLen + int(packetNumberLength)
  131. copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
  132. data := newPacket[extHdrLen : int(packetLen)+hdrLen]
  133. var keyLabel string
  134. var ivLabel string
  135. switch versionNumber {
  136. case qtls.Version2:
  137. keyLabel = qtls.HKDFLabelKeyV2
  138. ivLabel = qtls.HKDFLabelIVV2
  139. default:
  140. keyLabel = qtls.HKDFLabelKeyV1
  141. ivLabel = qtls.HKDFLabelIVV1
  142. }
  143. key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
  144. iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
  145. cipher := qtls.AEADAESGCMTLS13(key, iv)
  146. nonce := make([]byte, int32(cipher.NonceSize()))
  147. binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
  148. decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen])
  149. if err != nil {
  150. return nil, err
  151. }
  152. var frameType byte
  153. var frameLen uint64
  154. var fragments []struct {
  155. offset uint64
  156. length uint64
  157. payload []byte
  158. }
  159. decryptedReader := bytes.NewReader(decrypted)
  160. for {
  161. frameType, err = decryptedReader.ReadByte()
  162. if err == io.EOF {
  163. break
  164. }
  165. switch frameType {
  166. case 0x00: // PADDING
  167. continue
  168. case 0x01: // PING
  169. continue
  170. case 0x02, 0x03: // ACK
  171. _, err = qtls.ReadUvarint(decryptedReader) // Largest Acknowledged
  172. if err != nil {
  173. return nil, err
  174. }
  175. _, err = qtls.ReadUvarint(decryptedReader) // ACK Delay
  176. if err != nil {
  177. return nil, err
  178. }
  179. ackRangeCount, err := qtls.ReadUvarint(decryptedReader) // ACK Range Count
  180. if err != nil {
  181. return nil, err
  182. }
  183. _, err = qtls.ReadUvarint(decryptedReader) // First ACK Range
  184. if err != nil {
  185. return nil, err
  186. }
  187. for i := 0; i < int(ackRangeCount); i++ {
  188. _, err = qtls.ReadUvarint(decryptedReader) // Gap
  189. if err != nil {
  190. return nil, err
  191. }
  192. _, err = qtls.ReadUvarint(decryptedReader) // ACK Range Length
  193. if err != nil {
  194. return nil, err
  195. }
  196. }
  197. if frameType == 0x03 {
  198. _, err = qtls.ReadUvarint(decryptedReader) // ECT0 Count
  199. if err != nil {
  200. return nil, err
  201. }
  202. _, err = qtls.ReadUvarint(decryptedReader) // ECT1 Count
  203. if err != nil {
  204. return nil, err
  205. }
  206. _, err = qtls.ReadUvarint(decryptedReader) // ECN-CE Count
  207. if err != nil {
  208. return nil, err
  209. }
  210. }
  211. case 0x06: // CRYPTO
  212. var offset uint64
  213. offset, err = qtls.ReadUvarint(decryptedReader)
  214. if err != nil {
  215. return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
  216. }
  217. var length uint64
  218. length, err = qtls.ReadUvarint(decryptedReader)
  219. if err != nil {
  220. return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
  221. }
  222. index := len(decrypted) - decryptedReader.Len()
  223. fragments = append(fragments, struct {
  224. offset uint64
  225. length uint64
  226. payload []byte
  227. }{offset, length, decrypted[index : index+int(length)]})
  228. frameLen += length
  229. _, err = decryptedReader.Seek(int64(length), io.SeekCurrent)
  230. if err != nil {
  231. return nil, err
  232. }
  233. case 0x1c: // CONNECTION_CLOSE
  234. _, err = qtls.ReadUvarint(decryptedReader) // Error Code
  235. if err != nil {
  236. return nil, err
  237. }
  238. _, err = qtls.ReadUvarint(decryptedReader) // Frame Type
  239. if err != nil {
  240. return nil, err
  241. }
  242. var length uint64
  243. length, err = qtls.ReadUvarint(decryptedReader) // Reason Phrase Length
  244. if err != nil {
  245. return nil, err
  246. }
  247. _, err = decryptedReader.Seek(int64(length), io.SeekCurrent) // Reason Phrase
  248. if err != nil {
  249. return nil, err
  250. }
  251. default:
  252. return nil, os.ErrInvalid
  253. }
  254. }
  255. tlsHdr := make([]byte, 5)
  256. tlsHdr[0] = 0x16
  257. binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303))
  258. binary.BigEndian.PutUint16(tlsHdr[3:], uint16(frameLen))
  259. var index uint64
  260. var length int
  261. var readers []io.Reader
  262. readers = append(readers, bytes.NewReader(tlsHdr))
  263. find:
  264. for {
  265. for _, fragment := range fragments {
  266. if fragment.offset == index {
  267. readers = append(readers, bytes.NewReader(fragment.payload))
  268. index = fragment.offset + fragment.length
  269. length++
  270. continue find
  271. }
  272. }
  273. if length == len(fragments) {
  274. break
  275. }
  276. return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, E.New("bad fragments")
  277. }
  278. metadata, err := TLSClientHello(ctx, io.MultiReader(readers...))
  279. if err != nil {
  280. return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
  281. }
  282. metadata.Protocol = C.ProtocolQUIC
  283. return metadata, nil
  284. }