Browse Source

VLESS Encryption: Re-add automatically ChaCha20-Poly1305

https://github.com/XTLS/Xray-core/pull/5067#issuecomment-3234892060

Fixes https://github.com/XTLS/Xray-core/pull/4952#issuecomment-3234083367 for cheap routers
RPRX 3 months ago
parent
commit
82ea7a3cc5
3 changed files with 67 additions and 54 deletions
  1. 15 14
      proxy/vless/encryption/client.go
  2. 30 23
      proxy/vless/encryption/common.go
  3. 22 17
      proxy/vless/encryption/server.go

+ 15 - 14
proxy/vless/encryption/client.go

@@ -12,6 +12,7 @@ import (
 
 	"github.com/xtls/xray-core/common/crypto"
 	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/protocol"
 	"lukechampine.com/blake3"
 )
 
@@ -66,7 +67,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	if i.NfsPKeys == nil {
 		return nil, errors.New("uninitialized")
 	}
-	c := NewCommonConn(conn)
+	c := NewCommonConn(conn, protocol.HasAESGCMHardwareSupport)
 
 	ivAndRealysLength := 16 + i.RelaysLength
 	pfsKeyExchangeLength := 18 + 1184 + 32 + 16
@@ -108,18 +109,18 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 		lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:])
 		relays = relays[index+32:]
 	}
-	nfsGCM := NewGCM(iv, nfsKey)
+	nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
 
 	if i.Seconds > 0 {
 		i.RWLock.RLock()
 		if time.Now().Before(i.Expire) {
 			c.Client = i
 			c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection
-			nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
-			nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
+			nfsAEAD.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
+			nfsAEAD.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
 			i.RWLock.RUnlock()
 			c.PreWrite = clientHello[:ivAndRealysLength+18+32]
-			c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey)
+			c.AEAD = NewAEAD(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey, c.UseAES)
 			if i.XorMode == 2 {
 				c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16)
 			}
@@ -129,15 +130,15 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	}
 
 	pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength]
-	nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
+	nfsAEAD.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
 	mlkem768DKey, _ := mlkem.GenerateKey768()
 	x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
 	pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...)
-	nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
+	nfsAEAD.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
 
 	padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:]
-	nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
-	nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
+	nfsAEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
+	nfsAEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
 
 	if _, err := conn.Write(clientHello); err != nil {
 		return nil, err
@@ -148,7 +149,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
 		return nil, err
 	}
-	nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
+	nfsAEAD.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
 	mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088])
 	if err != nil {
 		return nil, err
@@ -165,14 +166,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	copy(pfsKey, mlkem768Key)
 	copy(pfsKey[32:], x25519Key)
 	c.UnitedKey = append(pfsKey, nfsKey...)
-	c.GCM = NewGCM(pfsPublicKey, c.UnitedKey)
-	c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey)
+	c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
+	c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1088+32], c.UnitedKey, c.UseAES)
 
 	encryptedTicket := make([]byte, 32)
 	if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
 		return nil, err
 	}
-	if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
+	if _, err := c.PeerAEAD.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
 		return nil, err
 	}
 	seconds := DecodeLength(encryptedTicket)
@@ -189,7 +190,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
 		return nil, err
 	}
-	if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
+	if _, err := c.PeerAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
 		return nil, err
 	}
 	length := DecodeLength(encryptedLength[:2])

+ 30 - 23
proxy/vless/encryption/common.go

@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"github.com/xtls/xray-core/common/errors"
+	"golang.org/x/crypto/chacha20poly1305"
 	"lukechampine.com/blake3"
 )
 
@@ -23,19 +24,21 @@ var OutBytesPool = sync.Pool{
 
 type CommonConn struct {
 	net.Conn
+	UseAES      bool
 	Client      *ClientInstance
 	UnitedKey   []byte
 	PreWrite    []byte
-	GCM         *GCM
-	PeerGCM     *GCM
+	AEAD        *AEAD
+	PeerAEAD    *AEAD
 	PeerPadding []byte
 	PeerInBytes []byte
 	PeerCache   []byte
 }
 
-func NewCommonConn(conn net.Conn) *CommonConn {
+func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
 	return &CommonConn{
 		Conn:        conn,
+		UseAES:      useAES,
 		PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
 	}
 }
@@ -55,12 +58,12 @@ func (c *CommonConn) Write(b []byte) (int, error) {
 		headerAndData := outBytes[:5+len(b)+16]
 		EncodeHeader(headerAndData, len(b)+16)
 		max := false
-		if bytes.Equal(c.GCM.Nonce[:], MaxNonce) {
+		if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
 			max = true
 		}
-		c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5])
+		c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
 		if max {
-			c.GCM = NewGCM(headerAndData, c.UnitedKey)
+			c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
 		}
 		if c.PreWrite != nil {
 			headerAndData = append(c.PreWrite, headerAndData...)
@@ -77,12 +80,12 @@ func (c *CommonConn) Read(b []byte) (int, error) {
 	if len(b) == 0 {
 		return 0, nil
 	}
-	if c.PeerGCM == nil { // client's 0-RTT
+	if c.PeerAEAD == nil { // client's 0-RTT
 		serverRandom := make([]byte, 16)
 		if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
 			return 0, err
 		}
-		c.PeerGCM = NewGCM(serverRandom, c.UnitedKey)
+		c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
 		if xorConn, ok := c.Conn.(*XorConn); ok {
 			xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
 		}
@@ -91,7 +94,7 @@ func (c *CommonConn) Read(b []byte) (int, error) {
 		if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
 			return 0, err
 		}
-		if _, err := c.PeerGCM.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
+		if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
 			return 0, err
 		}
 		c.PeerPadding = nil
@@ -126,13 +129,13 @@ func (c *CommonConn) Read(b []byte) (int, error) {
 	if len(dst) <= len(b) {
 		dst = b[:len(dst)] // avoids another copy()
 	}
-	var newGCM *GCM
-	if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) {
-		newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey)
+	var newAEAD *AEAD
+	if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
+		newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES)
 	}
-	_, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader)
-	if newGCM != nil {
-		c.PeerGCM = newGCM
+	_, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader)
+	if newAEAD != nil {
+		c.PeerAEAD = newAEAD
 	}
 	if err != nil {
 		return 0, err
@@ -144,28 +147,32 @@ func (c *CommonConn) Read(b []byte) (int, error) {
 	return len(dst), nil
 }
 
-type GCM struct {
+type AEAD struct {
 	cipher.AEAD
 	Nonce [12]byte
 }
 
-func NewGCM(ctx, key []byte) *GCM {
+func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
 	k := make([]byte, 32)
 	blake3.DeriveKey(k, string(ctx), key)
-	block, _ := aes.NewCipher(k)
-	aead, _ := cipher.NewGCM(block)
-	return &GCM{AEAD: aead}
-	//chacha20poly1305.New()
+	var aead cipher.AEAD
+	if useAES {
+		block, _ := aes.NewCipher(k)
+		aead, _ = cipher.NewGCM(block)
+	} else {
+		aead, _ = chacha20poly1305.New(k)
+	}
+	return &AEAD{AEAD: aead}
 }
 
-func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
+func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
 	if nonce == nil {
 		nonce = IncreaseNonce(a.Nonce[:])
 	}
 	return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
 }
 
-func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
+func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
 	if nonce == nil {
 		nonce = IncreaseNonce(a.Nonce[:])
 	}

+ 22 - 17
proxy/vless/encryption/server.go

@@ -102,7 +102,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	if i.NfsSKeys == nil {
 		return nil, errors.New("uninitialized")
 	}
-	c := NewCommonConn(conn)
+	c := NewCommonConn(conn, true)
 
 	ivAndRelays := make([]byte, 16+i.RelaysLength)
 	if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
@@ -151,16 +151,21 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 		}
 		relays = relays[32:]
 	}
-	nfsGCM := NewGCM(iv, nfsKey)
+	nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
 
 	encryptedLength := make([]byte, 18)
 	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
 		return nil, err
 	}
-	if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
-		return nil, err
+	decryptedLength := make([]byte, 2)
+	if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
+		c.UseAES = !c.UseAES
+		nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
+		if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
+			return nil, err
+		}
 	}
-	length := DecodeLength(encryptedLength[:2])
+	length := DecodeLength(decryptedLength)
 
 	if length == 32 {
 		if i.Seconds == 0 {
@@ -170,7 +175,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 		if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
 			return nil, err
 		}
-		ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil)
+		ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
 		if err != nil {
 			return nil, err
 		}
@@ -193,8 +198,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 		c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
 		c.PreWrite = make([]byte, 16)
 		rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
-		c.GCM = NewGCM(c.PreWrite, c.UnitedKey)
-		c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
+		c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
+		c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
 		if i.XorMode == 2 {
 			c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
 		}
@@ -208,7 +213,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
 		return nil, err
 	}
-	if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
+	if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
 		return nil, err
 	}
 	mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
@@ -230,8 +235,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	copy(pfsKey[32:], x25519Key)
 	pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
 	c.UnitedKey = append(pfsKey, nfsKey...)
-	c.GCM = NewGCM(pfsPublicKey, c.UnitedKey)
-	c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey)
+	c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
+	c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
 	ticket := make([]byte, 16)
 	rand.Read(ticket)
 	copy(ticket, EncodeLength(int(i.Seconds*4/5)))
@@ -240,11 +245,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	encryptedTicketLength := 32
 	paddingLength := int(crypto.RandBetween(100, 1000))
 	serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
-	nfsGCM.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
-	c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
+	nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
+	c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
 	padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
-	c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
-	c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
+	c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
+	c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
 
 	if _, err := conn.Write(serverHello); err != nil {
 		return nil, err
@@ -264,14 +269,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
 	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
 		return nil, err
 	}
-	if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
+	if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
 		return nil, err
 	}
 	encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
 	if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
 		return nil, err
 	}
-	if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
+	if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
 		return nil, err
 	}