| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 | 
							- package encryption
 
- import (
 
- 	"bytes"
 
- 	"crypto/cipher"
 
- 	"crypto/ecdh"
 
- 	"crypto/mlkem"
 
- 	"crypto/rand"
 
- 	"fmt"
 
- 	"io"
 
- 	"net"
 
- 	"sync"
 
- 	"time"
 
- 	"github.com/xtls/xray-core/common/crypto"
 
- 	"github.com/xtls/xray-core/common/errors"
 
- 	"lukechampine.com/blake3"
 
- )
 
- type ServerSession struct {
 
- 	PfsKey  []byte
 
- 	NfsKeys sync.Map
 
- }
 
- type ServerInstance struct {
 
- 	NfsSKeys      []any
 
- 	NfsPKeysBytes [][]byte
 
- 	Hash32s       [][32]byte
 
- 	RelaysLength  int
 
- 	XorMode       uint32
 
- 	SecondsFrom   int64
 
- 	SecondsTo     int64
 
- 	PaddingLens   [][3]int
 
- 	PaddingGaps   [][3]int
 
- 	RWLock   sync.RWMutex
 
- 	Closed   bool
 
- 	Lasts    map[int64][16]byte
 
- 	Tickets  [][16]byte
 
- 	Sessions map[[16]byte]*ServerSession
 
- }
 
- func (i *ServerInstance) Init(nfsSKeysBytes [][]byte, xorMode uint32, secondsFrom, secondsTo int64, padding string) (err error) {
 
- 	if i.NfsSKeys != nil {
 
- 		return errors.New("already initialized")
 
- 	}
 
- 	l := len(nfsSKeysBytes)
 
- 	if l == 0 {
 
- 		return errors.New("empty nfsSKeysBytes")
 
- 	}
 
- 	i.NfsSKeys = make([]any, l)
 
- 	i.NfsPKeysBytes = make([][]byte, l)
 
- 	i.Hash32s = make([][32]byte, l)
 
- 	for j, k := range nfsSKeysBytes {
 
- 		if len(k) == 32 {
 
- 			if i.NfsSKeys[j], err = ecdh.X25519().NewPrivateKey(k); err != nil {
 
- 				return
 
- 			}
 
- 			i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*ecdh.PrivateKey).PublicKey().Bytes()
 
- 			i.RelaysLength += 32 + 32
 
- 		} else {
 
- 			if i.NfsSKeys[j], err = mlkem.NewDecapsulationKey768(k); err != nil {
 
- 				return
 
- 			}
 
- 			i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*mlkem.DecapsulationKey768).EncapsulationKey().Bytes()
 
- 			i.RelaysLength += 1088 + 32
 
- 		}
 
- 		i.Hash32s[j] = blake3.Sum256(i.NfsPKeysBytes[j])
 
- 	}
 
- 	i.RelaysLength -= 32
 
- 	i.XorMode = xorMode
 
- 	i.SecondsFrom = secondsFrom
 
- 	i.SecondsTo = secondsTo
 
- 	err = ParsePadding(padding, &i.PaddingLens, &i.PaddingGaps)
 
- 	if err != nil {
 
- 		return
 
- 	}
 
- 	if i.SecondsFrom > 0 || i.SecondsTo > 0 {
 
- 		i.Lasts = make(map[int64][16]byte)
 
- 		i.Tickets = make([][16]byte, 0, 1024)
 
- 		i.Sessions = make(map[[16]byte]*ServerSession)
 
- 		go func() {
 
- 			for {
 
- 				time.Sleep(time.Minute)
 
- 				i.RWLock.Lock()
 
- 				if i.Closed {
 
- 					i.RWLock.Unlock()
 
- 					return
 
- 				}
 
- 				minute := time.Now().Unix() / 60
 
- 				last := i.Lasts[minute]
 
- 				delete(i.Lasts, minute)
 
- 				delete(i.Lasts, minute-1) // for insurance
 
- 				if last != [16]byte{} {
 
- 					for j, ticket := range i.Tickets {
 
- 						delete(i.Sessions, ticket)
 
- 						if ticket == last {
 
- 							i.Tickets = i.Tickets[j+1:]
 
- 							break
 
- 						}
 
- 					}
 
- 				}
 
- 				i.RWLock.Unlock()
 
- 			}
 
- 		}()
 
- 	}
 
- 	return
 
- }
 
- func (i *ServerInstance) Close() (err error) {
 
- 	i.RWLock.Lock()
 
- 	i.Closed = true
 
- 	i.RWLock.Unlock()
 
- 	return
 
- }
 
- func (i *ServerInstance) Handshake(conn net.Conn, fallback *[]byte) (*CommonConn, error) {
 
- 	if i.NfsSKeys == nil {
 
- 		return nil, errors.New("uninitialized")
 
- 	}
 
- 	c := NewCommonConn(conn, true)
 
- 	ivAndRelays := make([]byte, 16+i.RelaysLength)
 
- 	if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if fallback != nil {
 
- 		*fallback = append(*fallback, ivAndRelays...)
 
- 	}
 
- 	iv := ivAndRelays[:16]
 
- 	relays := ivAndRelays[16:]
 
- 	var nfsKey []byte
 
- 	var lastCTR cipher.Stream
 
- 	for j, k := range i.NfsSKeys {
 
- 		if lastCTR != nil {
 
- 			lastCTR.XORKeyStream(relays, relays[:32]) // recover this relay
 
- 		}
 
- 		var index = 32
 
- 		if _, ok := k.(*mlkem.DecapsulationKey768); ok {
 
- 			index = 1088
 
- 		}
 
- 		if i.XorMode > 0 {
 
- 			NewCTR(i.NfsPKeysBytes[j], iv).XORKeyStream(relays, relays[:index]) // we don't use buggy elligator2, because we have PSK :)
 
- 		}
 
- 		if k, ok := k.(*ecdh.PrivateKey); ok {
 
- 			publicKey, err := ecdh.X25519().NewPublicKey(relays[:index])
 
- 			if err != nil {
 
- 				return nil, err
 
- 			}
 
- 			if publicKey.Bytes()[31] > 127 { // we just don't want the observer can change even one bit without breaking the connection, though it has nothing to do with security
 
- 				return nil, errors.New("the highest bit of the last byte of the peer-sent X25519 public key is not 0")
 
- 			}
 
- 			nfsKey, err = k.ECDH(publicKey)
 
- 			if err != nil {
 
- 				return nil, err
 
- 			}
 
- 		}
 
- 		if k, ok := k.(*mlkem.DecapsulationKey768); ok {
 
- 			var err error
 
- 			nfsKey, err = k.Decapsulate(relays[:index])
 
- 			if err != nil {
 
- 				return nil, err
 
- 			}
 
- 		}
 
- 		if j == len(i.NfsSKeys)-1 {
 
- 			break
 
- 		}
 
- 		relays = relays[index:]
 
- 		lastCTR = NewCTR(nfsKey, iv)
 
- 		lastCTR.XORKeyStream(relays, relays[:32])
 
- 		if !bytes.Equal(relays[:32], i.Hash32s[j+1][:]) {
 
- 			return nil, errors.New("unexpected hash32: ", fmt.Sprintf("%v", relays[:32]))
 
- 		}
 
- 		relays = relays[32:]
 
- 	}
 
- 	nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
 
- 	encryptedLength := make([]byte, 18)
 
- 	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if fallback != nil {
 
- 		*fallback = append(*fallback, encryptedLength...)
 
- 	}
 
- 	decryptedLength := make([]byte, 2)
 
- 	if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
 
- 		c.UseAES = !c.UseAES
 
- 		nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
 
- 		if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
 
- 			return nil, err
 
- 		}
 
- 	}
 
- 	if fallback != nil {
 
- 		*fallback = nil
 
- 	}
 
- 	length := DecodeLength(decryptedLength)
 
- 	if length == 32 {
 
- 		if i.SecondsFrom == 0 && i.SecondsTo == 0 {
 
- 			return nil, errors.New("0-RTT is not allowed")
 
- 		}
 
- 		encryptedTicket := make([]byte, 32)
 
- 		if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
 
- 			return nil, err
 
- 		}
 
- 		ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
 
- 		if err != nil {
 
- 			return nil, err
 
- 		}
 
- 		i.RWLock.RLock()
 
- 		s := i.Sessions[[16]byte(ticket)]
 
- 		i.RWLock.RUnlock()
 
- 		if s == nil {
 
- 			noises := make([]byte, crypto.RandBetween(1279, 2279)) // matches 1-RTT's server hello length for "random", though it is not important, just for example
 
- 			var err error
 
- 			for err == nil {
 
- 				rand.Read(noises)
 
- 				_, err = DecodeHeader(noises)
 
- 			}
 
- 			conn.Write(noises) // make client do new handshake
 
- 			return nil, errors.New("expired ticket")
 
- 		}
 
- 		if _, loaded := s.NfsKeys.LoadOrStore([32]byte(nfsKey), true); loaded { // prevents bad client also
 
- 			return nil, errors.New("replay detected")
 
- 		}
 
- 		c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
 
- 		c.PreWrite = make([]byte, 16)
 
- 		rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
 
- 		c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
 
- 		c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
 
- 		if i.XorMode == 2 {
 
- 			c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
 
- 		}
 
- 		return c, nil
 
- 	}
 
- 	if length < 1184+32+16 { // client may send more public keys in the future's version
 
- 		return nil, errors.New("too short length")
 
- 	}
 
- 	encryptedPfsPublicKey := make([]byte, length)
 
- 	if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
 
- 	if err != nil {
 
- 		return nil, err
 
- 	}
 
- 	mlkem768Key, encapsulatedPfsKey := mlkem768EKey.Encapsulate()
 
- 	peerX25519PKey, err := ecdh.X25519().NewPublicKey(encryptedPfsPublicKey[1184 : 1184+32])
 
- 	if err != nil {
 
- 		return nil, err
 
- 	}
 
- 	x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
 
- 	x25519Key, err := x25519SKey.ECDH(peerX25519PKey)
 
- 	if err != nil {
 
- 		return nil, err
 
- 	}
 
- 	pfsKey := make([]byte, 32+32) // no more capacity
 
- 	copy(pfsKey, mlkem768Key)
 
- 	copy(pfsKey[32:], x25519Key)
 
- 	pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
 
- 	c.UnitedKey = append(pfsKey, nfsKey...)
 
- 	c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
 
- 	c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
 
- 	ticket := [16]byte{}
 
- 	rand.Read(ticket[:])
 
- 	var seconds int64
 
- 	if i.SecondsTo == 0 {
 
- 		seconds = i.SecondsFrom * crypto.RandBetween(50, 100) / 100
 
- 	} else {
 
- 		seconds = crypto.RandBetween(i.SecondsFrom, i.SecondsTo)
 
- 	}
 
- 	copy(ticket[:], EncodeLength(int(seconds)))
 
- 	if seconds > 0 {
 
- 		i.RWLock.Lock()
 
- 		i.Lasts[(time.Now().Unix()+max(i.SecondsFrom, i.SecondsTo))/60+2] = ticket
 
- 		i.Tickets = append(i.Tickets, ticket)
 
- 		i.Sessions[ticket] = &ServerSession{PfsKey: pfsKey}
 
- 		i.RWLock.Unlock()
 
- 	}
 
- 	pfsKeyExchangeLength := 1088 + 32 + 16
 
- 	encryptedTicketLength := 32
 
- 	paddingLength, paddingLens, paddingGaps := CreatPadding(i.PaddingLens, i.PaddingGaps)
 
- 	serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
 
- 	nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
 
- 	c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket[:], nil)
 
- 	padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
 
- 	c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
 
- 	c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
 
- 	paddingLens[0] = pfsKeyExchangeLength + encryptedTicketLength + paddingLens[0]
 
- 	for i, l := range paddingLens { // sends padding in a fragmented way, to create variable traffic pattern, before inner VLESS flow takes control
 
- 		if l > 0 {
 
- 			if _, err := conn.Write(serverHello[:l]); err != nil {
 
- 				return nil, err
 
- 			}
 
- 			serverHello = serverHello[l:]
 
- 		}
 
- 		if len(paddingGaps) > i {
 
- 			time.Sleep(paddingGaps[i])
 
- 		}
 
- 	}
 
- 	// important: allows client sends padding slowly, eliminating 1-RTT's traffic pattern
 
- 	if _, err := io.ReadFull(conn, encryptedLength); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
 
- 	if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
 
- 		return nil, err
 
- 	}
 
- 	if i.XorMode == 2 {
 
- 		c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, ticket[:]), NewCTR(c.UnitedKey, iv), 0, 0)
 
- 	}
 
- 	return c, nil
 
- }
 
 
  |