| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- package encryption
- import (
- "bytes"
- "crypto/cipher"
- "crypto/ecdh"
- "crypto/mlkem"
- "crypto/rand"
- "fmt"
- "io"
- "net"
- "sync"
- "time"
- "github.com/xtls/xray-core/common/crypto"
- "github.com/xtls/xray-core/common/errors"
- "lukechampine.com/blake3"
- )
// ServerSession is the state retained for one resumable (0-RTT) session; instances
// are stored in ServerInstance.Sessions under the ticket that was issued for them.
type ServerSession struct {
	// PfsKey is the 64-byte forward-secret key negotiated by the 1-RTT handshake that
	// issued the ticket: the ML-KEM-768 shared key followed by the X25519 shared key.
	PfsKey []byte

	// NfsKeys remembers every nfsKey ([32]byte -> true) already used to resume this
	// session, so that presenting the same one twice is rejected as a replay.
	NfsKeys sync.Map
}
- type ServerInstance struct {
- NfsSKeys []any
- NfsPKeysBytes [][]byte
- Hash32s [][32]byte
- RelaysLength int
- XorMode uint32
- SecondsFrom int64
- SecondsTo int64
- PaddingLens [][3]int
- PaddingGaps [][3]int
- RWLock sync.RWMutex
- Closed bool
- Lasts map[int64][16]byte
- Tickets [][16]byte
- Sessions map[[16]byte]*ServerSession
- }
- func (i *ServerInstance) Init(nfsSKeysBytes [][]byte, xorMode uint32, secondsFrom, secondsTo int64, padding string) (err error) {
- if i.NfsSKeys != nil {
- return errors.New("already initialized")
- }
- l := len(nfsSKeysBytes)
- if l == 0 {
- return errors.New("empty nfsSKeysBytes")
- }
- i.NfsSKeys = make([]any, l)
- i.NfsPKeysBytes = make([][]byte, l)
- i.Hash32s = make([][32]byte, l)
- for j, k := range nfsSKeysBytes {
- if len(k) == 32 {
- if i.NfsSKeys[j], err = ecdh.X25519().NewPrivateKey(k); err != nil {
- return
- }
- i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*ecdh.PrivateKey).PublicKey().Bytes()
- i.RelaysLength += 32 + 32
- } else {
- if i.NfsSKeys[j], err = mlkem.NewDecapsulationKey768(k); err != nil {
- return
- }
- i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*mlkem.DecapsulationKey768).EncapsulationKey().Bytes()
- i.RelaysLength += 1088 + 32
- }
- i.Hash32s[j] = blake3.Sum256(i.NfsPKeysBytes[j])
- }
- i.RelaysLength -= 32
- i.XorMode = xorMode
- i.SecondsFrom = secondsFrom
- i.SecondsTo = secondsTo
- err = ParsePadding(padding, &i.PaddingLens, &i.PaddingGaps)
- if err != nil {
- return
- }
- if i.SecondsFrom > 0 || i.SecondsTo > 0 {
- i.Lasts = make(map[int64][16]byte)
- i.Tickets = make([][16]byte, 0, 1024)
- i.Sessions = make(map[[16]byte]*ServerSession)
- go func() {
- for {
- time.Sleep(time.Minute)
- i.RWLock.Lock()
- if i.Closed {
- i.RWLock.Unlock()
- return
- }
- minute := time.Now().Unix() / 60
- last := i.Lasts[minute]
- delete(i.Lasts, minute)
- delete(i.Lasts, minute-1) // for insurance
- if last != [16]byte{} {
- for j, ticket := range i.Tickets {
- delete(i.Sessions, ticket)
- if ticket == last {
- i.Tickets = i.Tickets[j+1:]
- break
- }
- }
- }
- i.RWLock.Unlock()
- }
- }()
- }
- return
- }
- func (i *ServerInstance) Close() (err error) {
- i.RWLock.Lock()
- i.Closed = true
- i.RWLock.Unlock()
- return
- }
- func (i *ServerInstance) Handshake(conn net.Conn, fallback *[]byte) (*CommonConn, error) {
- if i.NfsSKeys == nil {
- return nil, errors.New("uninitialized")
- }
- c := NewCommonConn(conn, true)
- ivAndRelays := make([]byte, 16+i.RelaysLength)
- if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
- return nil, err
- }
- if fallback != nil {
- *fallback = append(*fallback, ivAndRelays...)
- }
- iv := ivAndRelays[:16]
- relays := ivAndRelays[16:]
- var nfsKey []byte
- var lastCTR cipher.Stream
- for j, k := range i.NfsSKeys {
- if lastCTR != nil {
- lastCTR.XORKeyStream(relays, relays[:32]) // recover this relay
- }
- var index = 32
- if _, ok := k.(*mlkem.DecapsulationKey768); ok {
- index = 1088
- }
- if i.XorMode > 0 {
- NewCTR(i.NfsPKeysBytes[j], iv).XORKeyStream(relays, relays[:index]) // we don't use buggy elligator2, because we have PSK :)
- }
- if k, ok := k.(*ecdh.PrivateKey); ok {
- publicKey, err := ecdh.X25519().NewPublicKey(relays[:index])
- if err != nil {
- return nil, err
- }
- if publicKey.Bytes()[31] > 127 { // we just don't want the observer can change even one bit without breaking the connection, though it has nothing to do with security
- return nil, errors.New("the highest bit of the last byte of the peer-sent X25519 public key is not 0")
- }
- nfsKey, err = k.ECDH(publicKey)
- if err != nil {
- return nil, err
- }
- }
- if k, ok := k.(*mlkem.DecapsulationKey768); ok {
- var err error
- nfsKey, err = k.Decapsulate(relays[:index])
- if err != nil {
- return nil, err
- }
- }
- if j == len(i.NfsSKeys)-1 {
- break
- }
- relays = relays[index:]
- lastCTR = NewCTR(nfsKey, iv)
- lastCTR.XORKeyStream(relays, relays[:32])
- if !bytes.Equal(relays[:32], i.Hash32s[j+1][:]) {
- return nil, errors.New("unexpected hash32: ", fmt.Sprintf("%v", relays[:32]))
- }
- relays = relays[32:]
- }
- nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
- encryptedLength := make([]byte, 18)
- if _, err := io.ReadFull(conn, encryptedLength); err != nil {
- return nil, err
- }
- if fallback != nil {
- *fallback = append(*fallback, encryptedLength...)
- }
- decryptedLength := make([]byte, 2)
- if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
- c.UseAES = !c.UseAES
- nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
- if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
- return nil, err
- }
- }
- if fallback != nil {
- *fallback = nil
- }
- length := DecodeLength(decryptedLength)
- if length == 32 {
- if i.SecondsFrom == 0 && i.SecondsTo == 0 {
- return nil, errors.New("0-RTT is not allowed")
- }
- encryptedTicket := make([]byte, 32)
- if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
- return nil, err
- }
- ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
- if err != nil {
- return nil, err
- }
- i.RWLock.RLock()
- s := i.Sessions[[16]byte(ticket)]
- i.RWLock.RUnlock()
- if s == nil {
- noises := make([]byte, crypto.RandBetween(1279, 2279)) // matches 1-RTT's server hello length for "random", though it is not important, just for example
- var err error
- for err == nil {
- rand.Read(noises)
- _, err = DecodeHeader(noises)
- }
- conn.Write(noises) // make client do new handshake
- return nil, errors.New("expired ticket")
- }
- if _, loaded := s.NfsKeys.LoadOrStore([32]byte(nfsKey), true); loaded { // prevents bad client also
- return nil, errors.New("replay detected")
- }
- c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
- c.PreWrite = make([]byte, 16)
- rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
- c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
- c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
- if i.XorMode == 2 {
- c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
- }
- return c, nil
- }
- if length < 1184+32+16 { // client may send more public keys in the future's version
- return nil, errors.New("too short length")
- }
- encryptedPfsPublicKey := make([]byte, length)
- if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
- return nil, err
- }
- if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
- return nil, err
- }
- mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
- if err != nil {
- return nil, err
- }
- mlkem768Key, encapsulatedPfsKey := mlkem768EKey.Encapsulate()
- peerX25519PKey, err := ecdh.X25519().NewPublicKey(encryptedPfsPublicKey[1184 : 1184+32])
- if err != nil {
- return nil, err
- }
- x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
- x25519Key, err := x25519SKey.ECDH(peerX25519PKey)
- if err != nil {
- return nil, err
- }
- pfsKey := make([]byte, 32+32) // no more capacity
- copy(pfsKey, mlkem768Key)
- copy(pfsKey[32:], x25519Key)
- pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
- c.UnitedKey = append(pfsKey, nfsKey...)
- c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
- c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
- ticket := [16]byte{}
- rand.Read(ticket[:])
- var seconds int64
- if i.SecondsTo == 0 {
- seconds = i.SecondsFrom * crypto.RandBetween(50, 100) / 100
- } else {
- seconds = crypto.RandBetween(i.SecondsFrom, i.SecondsTo)
- }
- copy(ticket[:], EncodeLength(int(seconds)))
- if seconds > 0 {
- i.RWLock.Lock()
- i.Lasts[(time.Now().Unix()+max(i.SecondsFrom, i.SecondsTo))/60+2] = ticket
- i.Tickets = append(i.Tickets, ticket)
- i.Sessions[ticket] = &ServerSession{PfsKey: pfsKey}
- i.RWLock.Unlock()
- }
- pfsKeyExchangeLength := 1088 + 32 + 16
- encryptedTicketLength := 32
- paddingLength, paddingLens, paddingGaps := CreatPadding(i.PaddingLens, i.PaddingGaps)
- serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
- nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
- c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket[:], nil)
- padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
- c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
- c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
- paddingLens[0] = pfsKeyExchangeLength + encryptedTicketLength + paddingLens[0]
- for i, l := range paddingLens { // sends padding in a fragmented way, to create variable traffic pattern, before inner VLESS flow takes control
- if l > 0 {
- if _, err := conn.Write(serverHello[:l]); err != nil {
- return nil, err
- }
- serverHello = serverHello[l:]
- }
- if len(paddingGaps) > i {
- time.Sleep(paddingGaps[i])
- }
- }
- // important: allows client sends padding slowly, eliminating 1-RTT's traffic pattern
- if _, err := io.ReadFull(conn, encryptedLength); err != nil {
- return nil, err
- }
- if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
- return nil, err
- }
- encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
- if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
- return nil, err
- }
- if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
- return nil, err
- }
- if i.XorMode == 2 {
- c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, ticket[:]), NewCTR(c.UnitedKey, iv), 0, 0)
- }
- return c, nil
- }
|