| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- package sniff
- import (
- "bytes"
- "context"
- "crypto"
- "crypto/aes"
- "encoding/binary"
- "io"
- "os"
- "github.com/sagernet/sing-box/adapter"
- "github.com/sagernet/sing-box/common/sniff/internal/qtls"
- C "github.com/sagernet/sing-box/constant"
- E "github.com/sagernet/sing/common/exceptions"
- "golang.org/x/crypto/hkdf"
- )
- func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
- reader := bytes.NewReader(packet)
- typeByte, err := reader.ReadByte()
- if err != nil {
- return nil, err
- }
- if typeByte&0x40 == 0 {
- return nil, E.New("bad type byte")
- }
- var versionNumber uint32
- err = binary.Read(reader, binary.BigEndian, &versionNumber)
- if err != nil {
- return nil, err
- }
- if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
- return nil, E.New("bad version")
- }
- packetType := (typeByte & 0x30) >> 4
- if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 {
- return nil, E.New("bad packet type")
- }
- destConnIDLen, err := reader.ReadByte()
- if err != nil {
- return nil, err
- }
- if destConnIDLen == 0 || destConnIDLen > 20 {
- return nil, E.New("bad destination connection id length")
- }
- destConnID := make([]byte, destConnIDLen)
- _, err = io.ReadFull(reader, destConnID)
- if err != nil {
- return nil, err
- }
- srcConnIDLen, err := reader.ReadByte()
- if err != nil {
- return nil, err
- }
- _, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen))
- if err != nil {
- return nil, err
- }
- tokenLen, err := qtls.ReadUvarint(reader)
- if err != nil {
- return nil, err
- }
- _, err = io.CopyN(io.Discard, reader, int64(tokenLen))
- if err != nil {
- return nil, err
- }
- packetLen, err := qtls.ReadUvarint(reader)
- if err != nil {
- return nil, err
- }
- hdrLen := int(reader.Size()) - reader.Len()
- if hdrLen+int(packetLen) > len(packet) {
- return nil, os.ErrInvalid
- }
- _, err = io.CopyN(io.Discard, reader, 4)
- if err != nil {
- return nil, err
- }
- pnBytes := make([]byte, aes.BlockSize)
- _, err = io.ReadFull(reader, pnBytes)
- if err != nil {
- return nil, err
- }
- var salt []byte
- switch versionNumber {
- case qtls.Version1:
- salt = qtls.SaltV1
- case qtls.Version2:
- salt = qtls.SaltV2
- default:
- salt = qtls.SaltOld
- }
- var hkdfHeaderProtectionLabel string
- switch versionNumber {
- case qtls.Version2:
- hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2
- default:
- hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1
- }
- initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
- secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
- hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16)
- block, err := aes.NewCipher(hpKey)
- if err != nil {
- return nil, err
- }
- mask := make([]byte, aes.BlockSize)
- block.Encrypt(mask, pnBytes)
- newPacket := make([]byte, len(packet))
- copy(newPacket, packet)
- newPacket[0] ^= mask[0] & 0xf
- for i := range newPacket[hdrLen : hdrLen+4] {
- newPacket[hdrLen+i] ^= mask[i+1]
- }
- packetNumberLength := newPacket[0]&0x3 + 1
- if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen {
- return nil, os.ErrInvalid
- }
- var packetNumber uint32
- switch packetNumberLength {
- case 1:
- packetNumber = uint32(newPacket[hdrLen])
- case 2:
- packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:]))
- case 3:
- packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16
- case 4:
- packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:])
- default:
- return nil, E.New("bad packet number length")
- }
- extHdrLen := hdrLen + int(packetNumberLength)
- copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
- data := newPacket[extHdrLen : int(packetLen)+hdrLen]
- var keyLabel string
- var ivLabel string
- switch versionNumber {
- case qtls.Version2:
- keyLabel = qtls.HKDFLabelKeyV2
- ivLabel = qtls.HKDFLabelIVV2
- default:
- keyLabel = qtls.HKDFLabelKeyV1
- ivLabel = qtls.HKDFLabelIVV1
- }
- key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
- iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
- cipher := qtls.AEADAESGCMTLS13(key, iv)
- nonce := make([]byte, int32(cipher.NonceSize()))
- binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
- decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen])
- if err != nil {
- return nil, err
- }
- var frameType byte
- var frameLen uint64
- var fragments []struct {
- offset uint64
- length uint64
- payload []byte
- }
- decryptedReader := bytes.NewReader(decrypted)
- for {
- frameType, err = decryptedReader.ReadByte()
- if err == io.EOF {
- break
- }
- switch frameType {
- case 0x00: // PADDING
- continue
- case 0x01: // PING
- continue
- case 0x02, 0x03: // ACK
- _, err = qtls.ReadUvarint(decryptedReader) // Largest Acknowledged
- if err != nil {
- return nil, err
- }
- _, err = qtls.ReadUvarint(decryptedReader) // ACK Delay
- if err != nil {
- return nil, err
- }
- ackRangeCount, err := qtls.ReadUvarint(decryptedReader) // ACK Range Count
- if err != nil {
- return nil, err
- }
- _, err = qtls.ReadUvarint(decryptedReader) // First ACK Range
- if err != nil {
- return nil, err
- }
- for i := 0; i < int(ackRangeCount); i++ {
- _, err = qtls.ReadUvarint(decryptedReader) // Gap
- if err != nil {
- return nil, err
- }
- _, err = qtls.ReadUvarint(decryptedReader) // ACK Range Length
- if err != nil {
- return nil, err
- }
- }
- if frameType == 0x03 {
- _, err = qtls.ReadUvarint(decryptedReader) // ECT0 Count
- if err != nil {
- return nil, err
- }
- _, err = qtls.ReadUvarint(decryptedReader) // ECT1 Count
- if err != nil {
- return nil, err
- }
- _, err = qtls.ReadUvarint(decryptedReader) // ECN-CE Count
- if err != nil {
- return nil, err
- }
- }
- case 0x06: // CRYPTO
- var offset uint64
- offset, err = qtls.ReadUvarint(decryptedReader)
- if err != nil {
- return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
- }
- var length uint64
- length, err = qtls.ReadUvarint(decryptedReader)
- if err != nil {
- return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
- }
- index := len(decrypted) - decryptedReader.Len()
- fragments = append(fragments, struct {
- offset uint64
- length uint64
- payload []byte
- }{offset, length, decrypted[index : index+int(length)]})
- frameLen += length
- _, err = decryptedReader.Seek(int64(length), io.SeekCurrent)
- if err != nil {
- return nil, err
- }
- case 0x1c: // CONNECTION_CLOSE
- _, err = qtls.ReadUvarint(decryptedReader) // Error Code
- if err != nil {
- return nil, err
- }
- _, err = qtls.ReadUvarint(decryptedReader) // Frame Type
- if err != nil {
- return nil, err
- }
- var length uint64
- length, err = qtls.ReadUvarint(decryptedReader) // Reason Phrase Length
- if err != nil {
- return nil, err
- }
- _, err = decryptedReader.Seek(int64(length), io.SeekCurrent) // Reason Phrase
- if err != nil {
- return nil, err
- }
- default:
- return nil, os.ErrInvalid
- }
- }
- tlsHdr := make([]byte, 5)
- tlsHdr[0] = 0x16
- binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303))
- binary.BigEndian.PutUint16(tlsHdr[3:], uint16(frameLen))
- var index uint64
- var length int
- var readers []io.Reader
- readers = append(readers, bytes.NewReader(tlsHdr))
- find:
- for {
- for _, fragment := range fragments {
- if fragment.offset == index {
- readers = append(readers, bytes.NewReader(fragment.payload))
- index = fragment.offset + fragment.length
- length++
- continue find
- }
- }
- if length == len(fragments) {
- break
- }
- return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, E.New("bad fragments")
- }
- metadata, err := TLSClientHello(ctx, io.MultiReader(readers...))
- if err != nil {
- return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
- }
- metadata.Protocol = C.ProtocolQUIC
- return metadata, nil
- }
|