| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- package encryption
- import (
- "bytes"
- "crypto/aes"
- "crypto/cipher"
- "fmt"
- "io"
- "net"
- "strconv"
- "strings"
- "sync"
- "time"
- "github.com/xtls/xray-core/common/crypto"
- "github.com/xtls/xray-core/common/errors"
- "golang.org/x/crypto/chacha20poly1305"
- "lukechampine.com/blake3"
- )
- var OutBytesPool = sync.Pool{
- New: func() any {
- return make([]byte, 5+8192+16)
- },
- }
- type CommonConn struct {
- net.Conn
- UseAES bool
- Client *ClientInstance
- UnitedKey []byte
- PreWrite []byte
- AEAD *AEAD
- PeerAEAD *AEAD
- PeerPadding []byte
- rawInput bytes.Buffer
- input bytes.Reader
- }
- func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
- return &CommonConn{
- Conn: conn,
- UseAES: useAES,
- }
- }
- func (c *CommonConn) Write(b []byte) (int, error) {
- if len(b) == 0 {
- return 0, nil
- }
- outBytes := OutBytesPool.Get().([]byte)
- defer OutBytesPool.Put(outBytes)
- for n := 0; n < len(b); {
- b := b[n:]
- if len(b) > 8192 {
- b = b[:8192] // for avoiding another copy() in peer's Read()
- }
- n += len(b)
- headerAndData := outBytes[:5+len(b)+16]
- EncodeHeader(headerAndData, len(b)+16)
- max := false
- if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
- max = true
- }
- c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
- if max {
- c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
- }
- if c.PreWrite != nil {
- headerAndData = append(c.PreWrite, headerAndData...)
- c.PreWrite = nil
- }
- if _, err := c.Conn.Write(headerAndData); err != nil {
- return 0, err
- }
- }
- return len(b), nil
- }
- func (c *CommonConn) Read(b []byte) (int, error) {
- if len(b) == 0 {
- return 0, nil
- }
- if c.PeerAEAD == nil { // client's 0-RTT
- serverRandom := make([]byte, 16)
- if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
- return 0, err
- }
- c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
- if xorConn, ok := c.Conn.(*XorConn); ok {
- xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
- }
- }
- if c.PeerPadding != nil { // client's 1-RTT
- if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
- return 0, err
- }
- if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
- return 0, err
- }
- c.PeerPadding = nil
- }
- if c.input.Len() > 0 {
- return c.input.Read(b)
- }
- peerHeader := [5]byte{}
- if _, err := io.ReadFull(c.Conn, peerHeader[:]); err != nil {
- return 0, err
- }
- l, err := DecodeHeader(peerHeader[:]) // l: 17~17000
- if err != nil {
- if c.Client != nil && strings.Contains(err.Error(), "invalid header: ") { // client's 0-RTT
- c.Client.RWLock.Lock()
- if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) {
- c.Client.Expire = time.Now() // expired
- }
- c.Client.RWLock.Unlock()
- return 0, errors.New("new handshake needed")
- }
- return 0, err
- }
- c.Client = nil
- if c.rawInput.Cap() < l {
- c.rawInput.Grow(l) // no need to use sync.Pool, because we are always reading
- }
- peerData := c.rawInput.Bytes()[:l]
- if _, err := io.ReadFull(c.Conn, peerData); err != nil {
- return 0, err
- }
- dst := peerData[:l-16]
- if len(dst) <= len(b) {
- dst = b[:len(dst)] // avoids another copy()
- }
- var newAEAD *AEAD
- if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
- newAEAD = NewAEAD(append(peerHeader[:], peerData...), c.UnitedKey, c.UseAES)
- }
- _, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader[:])
- if newAEAD != nil {
- c.PeerAEAD = newAEAD
- }
- if err != nil {
- return 0, err
- }
- if len(dst) > len(b) {
- c.input.Reset(dst[copy(b, dst):])
- dst = b // for len(dst)
- }
- return len(dst), nil
- }
- type AEAD struct {
- cipher.AEAD
- Nonce [12]byte
- }
- func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
- k := make([]byte, 32)
- blake3.DeriveKey(k, string(ctx), key)
- var aead cipher.AEAD
- if useAES {
- block, _ := aes.NewCipher(k)
- aead, _ = cipher.NewGCM(block)
- } else {
- aead, _ = chacha20poly1305.New(k)
- }
- return &AEAD{AEAD: aead}
- }
- func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
- if nonce == nil {
- nonce = IncreaseNonce(a.Nonce[:])
- }
- return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
- }
- func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
- if nonce == nil {
- nonce = IncreaseNonce(a.Nonce[:])
- }
- return a.AEAD.Open(dst, nonce, ciphertext, additionalData)
- }
- func IncreaseNonce(nonce []byte) []byte {
- for i := range 12 {
- nonce[11-i]++
- if nonce[11-i] != 0 {
- break
- }
- }
- return nonce
- }
- var MaxNonce = bytes.Repeat([]byte{255}, 12)
- func EncodeLength(l int) []byte {
- return []byte{byte(l >> 8), byte(l)}
- }
- func DecodeLength(b []byte) int {
- return int(b[0])<<8 | int(b[1])
- }
- func EncodeHeader(h []byte, l int) {
- h[0] = 23
- h[1] = 3
- h[2] = 3
- h[3] = byte(l >> 8)
- h[4] = byte(l)
- }
- func DecodeHeader(h []byte) (l int, err error) {
- l = int(h[3])<<8 | int(h[4])
- if h[0] != 23 || h[1] != 3 || h[2] != 3 {
- l = 0
- }
- if l < 17 || l > 17000 { // TODO: TLSv1.3 max length
- err = errors.New("invalid header: " + fmt.Sprintf("%v", h[:5])) // DO NOT CHANGE: relied by client's Read()
- }
- return
- }
- func ParsePadding(padding string, paddingLens, paddingGaps *[][3]int) (err error) {
- if padding == "" {
- return
- }
- maxLen := 0
- for i, s := range strings.Split(padding, ".") {
- x := strings.Split(s, "-")
- if len(x) < 3 || x[0] == "" || x[1] == "" || x[2] == "" {
- return errors.New("invalid padding lenth/gap parameter: " + s)
- }
- y := [3]int{}
- if y[0], err = strconv.Atoi(x[0]); err != nil {
- return
- }
- if y[1], err = strconv.Atoi(x[1]); err != nil {
- return
- }
- if y[2], err = strconv.Atoi(x[2]); err != nil {
- return
- }
- if i == 0 && (y[0] < 100 || y[1] < 18+17 || y[2] < 18+17) {
- return errors.New("first padding length must not be smaller than 35")
- }
- if i%2 == 0 {
- *paddingLens = append(*paddingLens, y)
- maxLen += max(y[1], y[2])
- } else {
- *paddingGaps = append(*paddingGaps, y)
- }
- }
- if maxLen > 18+65535 {
- return errors.New("total padding length must not be larger than 65553")
- }
- return
- }
- func CreatPadding(paddingLens, paddingGaps [][3]int) (length int, lens []int, gaps []time.Duration) {
- if len(paddingLens) == 0 {
- paddingLens = [][3]int{{100, 111, 1111}, {50, 0, 3333}}
- paddingGaps = [][3]int{{75, 0, 111}}
- }
- for _, y := range paddingLens {
- l := 0
- if y[0] >= int(crypto.RandBetween(0, 100)) {
- l = int(crypto.RandBetween(int64(y[1]), int64(y[2])))
- }
- lens = append(lens, l)
- length += l
- }
- for _, y := range paddingGaps {
- g := 0
- if y[0] >= int(crypto.RandBetween(0, 100)) {
- g = int(crypto.RandBetween(int64(y[1]), int64(y[2])))
- }
- gaps = append(gaps, time.Duration(g)*time.Millisecond)
- }
- return
- }
|