| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- package encryption
- import (
- "bytes"
- "crypto/aes"
- "crypto/cipher"
- "fmt"
- "io"
- "net"
- "strings"
- "sync"
- "time"
- "github.com/xtls/xray-core/common/errors"
- "golang.org/x/crypto/chacha20poly1305"
- "lukechampine.com/blake3"
- )
// OutBytesPool recycles the scratch buffers CommonConn.Write builds outgoing
// records in. Each buffer is large enough for one full record: a 5-byte
// header, at most 8192 bytes of sealed payload, and the 16-byte AEAD tag.
var OutBytesPool = sync.Pool{
	New: func() any {
		const header, maxChunk, tag = 5, 8192, 16
		return make([]byte, header+maxChunk+tag)
	},
}
- type CommonConn struct {
- net.Conn
- UseAES bool
- Client *ClientInstance
- UnitedKey []byte
- PreWrite []byte
- AEAD *AEAD
- PeerAEAD *AEAD
- PeerPadding []byte
- PeerInBytes []byte
- PeerCache []byte
- }
- func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
- return &CommonConn{
- Conn: conn,
- UseAES: useAES,
- PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
- }
- }
- func (c *CommonConn) Write(b []byte) (int, error) {
- if len(b) == 0 {
- return 0, nil
- }
- outBytes := OutBytesPool.Get().([]byte)
- defer OutBytesPool.Put(outBytes)
- for n := 0; n < len(b); {
- b := b[n:]
- if len(b) > 8192 {
- b = b[:8192] // for avoiding another copy() in peer's Read()
- }
- n += len(b)
- headerAndData := outBytes[:5+len(b)+16]
- EncodeHeader(headerAndData, len(b)+16)
- max := false
- if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
- max = true
- }
- c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
- if max {
- c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
- }
- if c.PreWrite != nil {
- headerAndData = append(c.PreWrite, headerAndData...)
- c.PreWrite = nil
- }
- if _, err := c.Conn.Write(headerAndData); err != nil {
- return 0, err
- }
- }
- return len(b), nil
- }
// Read returns decrypted plaintext from the underlying connection in b.
//
// Before the first data record, it may consume each of the following once:
//   - When PeerAEAD is nil (client's 0-RTT), a 16-byte server random. PeerAEAD
//     is derived from it, and so is PeerCTR if the underlying conn is an *XorConn.
//   - When PeerPadding is non-nil (client's 1-RTT), len(PeerPadding) encrypted
//     bytes. These are opened (authenticated) with PeerAEAD and then dropped.
//
// Plaintext left over from an earlier call (PeerCache) is returned before any
// new record is read. Otherwise one record is read: a 5-byte header validated
// by DecodeHeader, then l bytes of ciphertext plus tag. When b is large enough,
// the plaintext is decrypted directly into b. Otherwise it is decrypted in
// place in PeerInBytes, and the part that does not fit is kept in PeerCache.
//
// If c.Client is still set and the header is invalid, Read sets
// c.Client.Expire to now (when UnitedKey starts with c.Client.PfsKey) and
// returns the error "new handshake needed".
func (c *CommonConn) Read(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	if c.PeerAEAD == nil { // client's 0-RTT
		serverRandom := make([]byte, 16)
		if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
			return 0, err
		}
		c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
		if xorConn, ok := c.Conn.(*XorConn); ok {
			xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
		}
	}
	if c.PeerPadding != nil { // client's 1-RTT
		if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
			return 0, err
		}
		// Only authenticity matters here; the decrypted padding is discarded.
		if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
			return 0, err
		}
		c.PeerPadding = nil
	}
	// PeerCache aliases PeerInBytes, so it must be drained before the next
	// record is read into that buffer.
	if len(c.PeerCache) > 0 {
		n := copy(b, c.PeerCache)
		c.PeerCache = c.PeerCache[n:]
		return n, nil
	}
	peerHeader := c.PeerInBytes[:5]
	if _, err := io.ReadFull(c.Conn, peerHeader); err != nil {
		return 0, err
	}
	l, err := DecodeHeader(c.PeerInBytes[:5]) // l: 17~17000
	if err != nil {
		// Matches the message text produced by DecodeHeader.
		if c.Client != nil && strings.Contains(err.Error(), "invalid header: ") { // client's 0-RTT
			c.Client.RWLock.Lock()
			if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) {
				c.Client.Expire = time.Now() // expired
			}
			c.Client.RWLock.Unlock()
			return 0, errors.New("new handshake needed")
		}
		return 0, err
	}
	// A well-formed header was seen: disable the 0-RTT expiry path above for
	// the rest of this connection.
	c.Client = nil
	peerData := c.PeerInBytes[5 : 5+l]
	if _, err := io.ReadFull(c.Conn, peerData); err != nil {
		return 0, err
	}
	// The last 16 bytes of the record are the AEAD tag.
	dst := peerData[:l-16]
	if len(dst) <= len(b) {
		dst = b[:len(dst)] // avoids another copy()
	}
	// Rekey when the nonce is about to wrap (mirrors Write). The new key is
	// derived from the record as received (header + ciphertext), so this must
	// happen before Open, which may overwrite peerData in place.
	var newAEAD *AEAD
	if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
		newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES)
	}
	_, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader)
	if newAEAD != nil {
		c.PeerAEAD = newAEAD
	}
	if err != nil {
		return 0, err
	}
	// Plaintext was decrypted in place in PeerInBytes because it is larger
	// than b: hand out what fits and cache the rest for later calls.
	if len(dst) > len(b) {
		c.PeerCache = dst[copy(b, dst):]
		dst = b // for len(dst)
	}
	return len(dst), nil
}
// AEAD pairs a cipher.AEAD with a 12-byte counter nonce that this type's Seal
// and Open methods advance when called with a nil nonce.
type AEAD struct {
	// The underlying cipher (AES-GCM or ChaCha20-Poly1305, see NewAEAD).
	cipher.AEAD

	// Nonce is a big-endian counter, incremented by IncreaseNonce before use.
	Nonce [12]byte
}
- func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
- k := make([]byte, 32)
- blake3.DeriveKey(k, string(ctx), key)
- var aead cipher.AEAD
- if useAES {
- block, _ := aes.NewCipher(k)
- aead, _ = cipher.NewGCM(block)
- } else {
- aead, _ = chacha20poly1305.New(k)
- }
- return &AEAD{AEAD: aead}
- }
- func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
- if nonce == nil {
- nonce = IncreaseNonce(a.Nonce[:])
- }
- return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
- }
- func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
- if nonce == nil {
- nonce = IncreaseNonce(a.Nonce[:])
- }
- return a.AEAD.Open(dst, nonce, ciphertext, additionalData)
- }
// IncreaseNonce adds one to nonce[0:12], treated as a big-endian counter, in
// place and returns nonce. All-0xff wraps around to all zeros. Bytes after
// index 11 are not touched, and a slice shorter than 12 bytes panics.
func IncreaseNonce(nonce []byte) []byte {
	for pos := 11; pos >= 0; pos-- {
		nonce[pos]++
		if nonce[pos] != 0 {
			// No carry into the next more-significant byte.
			break
		}
	}
	return nonce
}
// MaxNonce is the largest 12-byte counter value (every byte 0xff). Write and
// Read compare an AEAD's Nonce against it before use: the next increment wraps
// to zero, which triggers a rekey.
var MaxNonce = bytes.Repeat([]byte{0xff}, 12)
// EncodeLength returns the low 16 bits of l as a new 2-byte big-endian slice.
// Higher bits are silently dropped.
func EncodeLength(l int) []byte {
	out := make([]byte, 2)
	out[0] = byte(l >> 8)
	out[1] = byte(l)
	return out
}
// DecodeLength reads a 16-bit big-endian length from b[0:2]. Extra bytes are
// ignored, and a slice shorter than 2 bytes panics.
func DecodeLength(b []byte) int {
	hi, lo := int(b[0]), int(b[1])
	return hi<<8 + lo
}
// EncodeHeader writes a 5-byte record header into h[0:5]. The header is
// content type 23 and version bytes 3, 3 (the TLS application_data / TLS 1.2
// record prefix), followed by the low 16 bits of l in big-endian order.
func EncodeHeader(h []byte, l int) {
	h[0], h[1], h[2] = 23, 3, 3
	h[3], h[4] = byte(l>>8), byte(l)
}
- func DecodeHeader(h []byte) (l int, err error) {
- l = int(h[3])<<8 | int(h[4])
- if h[0] != 23 || h[1] != 3 || h[2] != 3 {
- l = 0
- }
- if l < 17 || l > 17000 { // TODO: TLSv1.3 max length
- err = errors.New("invalid header: ", fmt.Sprintf("%v", h[:5])) // DO NOT CHANGE: relied by client's Read()
- }
- return
- }
|