server.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package encryption
  2. import (
  3. "bytes"
  4. "crypto/cipher"
  5. "crypto/ecdh"
  6. "crypto/mlkem"
  7. "crypto/rand"
  8. "fmt"
  9. "io"
  10. "net"
  11. "sync"
  12. "time"
  13. "github.com/xtls/xray-core/common/crypto"
  14. "github.com/xtls/xray-core/common/errors"
  15. "lukechampine.com/blake3"
  16. )
// ServerSession is the server-side state stored for one issued 0-RTT ticket
// (see ServerInstance.Sessions).
type ServerSession struct {
	// PfsKey is the 64-byte key derived by the 1-RTT handshake that issued
	// the ticket: the ML-KEM-768 shared secret followed by the X25519 shared
	// secret. A resumed connection prefixes it to its own nfsKey to form
	// UnitedKey.
	PfsKey []byte

	// NfsKeys records, keyed by [32]byte, every nfsKey that has already been
	// used to resume this session; seeing one again is rejected as a replay.
	NfsKeys sync.Map
}
// ServerInstance holds the server configuration and 0-RTT ticket state for
// the VLESS encryption handshake. Init must succeed before Handshake is used.
type ServerInstance struct {
	// NfsSKeys holds one private key per configured NFS key, in relay order.
	// Each entry is either an *ecdh.PrivateKey (X25519) or a
	// *mlkem.DecapsulationKey768. Handshake treats a nil slice as
	// "uninitialized".
	NfsSKeys []any

	// NfsPKeysBytes holds the encoded public key matching each NfsSKeys entry.
	NfsPKeysBytes [][]byte

	// Hash32s holds the BLAKE3-256 digest of each NfsPKeysBytes entry; the
	// client proves which key comes next in the relay chain with it.
	Hash32s [][32]byte

	// RelaysLength is the number of relay bytes Handshake reads after the
	// 16-byte IV.
	RelaysLength int

	// XorMode: any value > 0 XORs each relay with a CTR stream keyed by the
	// matching public key; 2 additionally wraps the returned connection in a
	// XorConn.
	XorMode uint32

	// SecondsFrom and SecondsTo determine the seconds value encoded into
	// issued tickets and their expiry; when both are zero 0-RTT is refused.
	SecondsFrom int64
	SecondsTo   int64

	// PaddingLens and PaddingGaps are the padding spec parsed by ParsePadding
	// and passed to CreatPadding for every server hello.
	PaddingLens [][3]int
	PaddingGaps [][3]int

	// RWLock guards Closed, Lasts, Tickets and Sessions.
	RWLock sync.RWMutex

	// Closed is set by Close; the ticket-expiry goroutine exits once it sees it.
	Closed bool

	// Lasts maps an expiry minute (Unix seconds / 60) to the newest ticket
	// assigned to that minute.
	Lasts map[int64][16]byte

	// Tickets lists not-yet-expired tickets in issue order.
	Tickets [][16]byte

	// Sessions maps each not-yet-expired ticket to its session state.
	Sessions map[[16]byte]*ServerSession
}
  37. func (i *ServerInstance) Init(nfsSKeysBytes [][]byte, xorMode uint32, secondsFrom, secondsTo int64, padding string) (err error) {
  38. if i.NfsSKeys != nil {
  39. return errors.New("already initialized")
  40. }
  41. l := len(nfsSKeysBytes)
  42. if l == 0 {
  43. return errors.New("empty nfsSKeysBytes")
  44. }
  45. i.NfsSKeys = make([]any, l)
  46. i.NfsPKeysBytes = make([][]byte, l)
  47. i.Hash32s = make([][32]byte, l)
  48. for j, k := range nfsSKeysBytes {
  49. if len(k) == 32 {
  50. if i.NfsSKeys[j], err = ecdh.X25519().NewPrivateKey(k); err != nil {
  51. return
  52. }
  53. i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*ecdh.PrivateKey).PublicKey().Bytes()
  54. i.RelaysLength += 32 + 32
  55. } else {
  56. if i.NfsSKeys[j], err = mlkem.NewDecapsulationKey768(k); err != nil {
  57. return
  58. }
  59. i.NfsPKeysBytes[j] = i.NfsSKeys[j].(*mlkem.DecapsulationKey768).EncapsulationKey().Bytes()
  60. i.RelaysLength += 1088 + 32
  61. }
  62. i.Hash32s[j] = blake3.Sum256(i.NfsPKeysBytes[j])
  63. }
  64. i.RelaysLength -= 32
  65. i.XorMode = xorMode
  66. i.SecondsFrom = secondsFrom
  67. i.SecondsTo = secondsTo
  68. err = ParsePadding(padding, &i.PaddingLens, &i.PaddingGaps)
  69. if err != nil {
  70. return
  71. }
  72. if i.SecondsFrom > 0 || i.SecondsTo > 0 {
  73. i.Lasts = make(map[int64][16]byte)
  74. i.Tickets = make([][16]byte, 0, 1024)
  75. i.Sessions = make(map[[16]byte]*ServerSession)
  76. go func() {
  77. for {
  78. time.Sleep(time.Minute)
  79. i.RWLock.Lock()
  80. if i.Closed {
  81. i.RWLock.Unlock()
  82. return
  83. }
  84. minute := time.Now().Unix() / 60
  85. last := i.Lasts[minute]
  86. delete(i.Lasts, minute)
  87. delete(i.Lasts, minute-1) // for insurance
  88. if last != [16]byte{} {
  89. for j, ticket := range i.Tickets {
  90. delete(i.Sessions, ticket)
  91. if ticket == last {
  92. i.Tickets = i.Tickets[j+1:]
  93. break
  94. }
  95. }
  96. }
  97. i.RWLock.Unlock()
  98. }
  99. }()
  100. }
  101. return
  102. }
  103. func (i *ServerInstance) Close() (err error) {
  104. i.RWLock.Lock()
  105. i.Closed = true
  106. i.RWLock.Unlock()
  107. return
  108. }
  109. func (i *ServerInstance) Handshake(conn net.Conn, fallback *[]byte) (*CommonConn, error) {
  110. if i.NfsSKeys == nil {
  111. return nil, errors.New("uninitialized")
  112. }
  113. c := NewCommonConn(conn, true)
  114. ivAndRelays := make([]byte, 16+i.RelaysLength)
  115. if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
  116. return nil, err
  117. }
  118. if fallback != nil {
  119. *fallback = append(*fallback, ivAndRelays...)
  120. }
  121. iv := ivAndRelays[:16]
  122. relays := ivAndRelays[16:]
  123. var nfsKey []byte
  124. var lastCTR cipher.Stream
  125. for j, k := range i.NfsSKeys {
  126. if lastCTR != nil {
  127. lastCTR.XORKeyStream(relays, relays[:32]) // recover this relay
  128. }
  129. var index = 32
  130. if _, ok := k.(*mlkem.DecapsulationKey768); ok {
  131. index = 1088
  132. }
  133. if i.XorMode > 0 {
  134. NewCTR(i.NfsPKeysBytes[j], iv).XORKeyStream(relays, relays[:index]) // we don't use buggy elligator2, because we have PSK :)
  135. }
  136. if k, ok := k.(*ecdh.PrivateKey); ok {
  137. publicKey, err := ecdh.X25519().NewPublicKey(relays[:index])
  138. if err != nil {
  139. return nil, err
  140. }
  141. if publicKey.Bytes()[31] > 127 { // we just don't want the observer can change even one bit without breaking the connection, though it has nothing to do with security
  142. return nil, errors.New("the highest bit of the last byte of the peer-sent X25519 public key is not 0")
  143. }
  144. nfsKey, err = k.ECDH(publicKey)
  145. if err != nil {
  146. return nil, err
  147. }
  148. }
  149. if k, ok := k.(*mlkem.DecapsulationKey768); ok {
  150. var err error
  151. nfsKey, err = k.Decapsulate(relays[:index])
  152. if err != nil {
  153. return nil, err
  154. }
  155. }
  156. if j == len(i.NfsSKeys)-1 {
  157. break
  158. }
  159. relays = relays[index:]
  160. lastCTR = NewCTR(nfsKey, iv)
  161. lastCTR.XORKeyStream(relays, relays[:32])
  162. if !bytes.Equal(relays[:32], i.Hash32s[j+1][:]) {
  163. return nil, errors.New("unexpected hash32: ", fmt.Sprintf("%v", relays[:32]))
  164. }
  165. relays = relays[32:]
  166. }
  167. nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
  168. encryptedLength := make([]byte, 18)
  169. if _, err := io.ReadFull(conn, encryptedLength); err != nil {
  170. return nil, err
  171. }
  172. if fallback != nil {
  173. *fallback = append(*fallback, encryptedLength...)
  174. }
  175. decryptedLength := make([]byte, 2)
  176. if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
  177. c.UseAES = !c.UseAES
  178. nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
  179. if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
  180. return nil, err
  181. }
  182. }
  183. if fallback != nil {
  184. *fallback = nil
  185. }
  186. length := DecodeLength(decryptedLength)
  187. if length == 32 {
  188. if i.SecondsFrom == 0 && i.SecondsTo == 0 {
  189. return nil, errors.New("0-RTT is not allowed")
  190. }
  191. encryptedTicket := make([]byte, 32)
  192. if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
  193. return nil, err
  194. }
  195. ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
  196. if err != nil {
  197. return nil, err
  198. }
  199. i.RWLock.RLock()
  200. s := i.Sessions[[16]byte(ticket)]
  201. i.RWLock.RUnlock()
  202. if s == nil {
  203. noises := make([]byte, crypto.RandBetween(1279, 2279)) // matches 1-RTT's server hello length for "random", though it is not important, just for example
  204. var err error
  205. for err == nil {
  206. rand.Read(noises)
  207. _, err = DecodeHeader(noises)
  208. }
  209. conn.Write(noises) // make client do new handshake
  210. return nil, errors.New("expired ticket")
  211. }
  212. if _, loaded := s.NfsKeys.LoadOrStore([32]byte(nfsKey), true); loaded { // prevents bad client also
  213. return nil, errors.New("replay detected")
  214. }
  215. c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
  216. c.PreWrite = make([]byte, 16)
  217. rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
  218. c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
  219. c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
  220. if i.XorMode == 2 {
  221. c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
  222. }
  223. return c, nil
  224. }
  225. if length < 1184+32+16 { // client may send more public keys in the future's version
  226. return nil, errors.New("too short length")
  227. }
  228. encryptedPfsPublicKey := make([]byte, length)
  229. if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
  230. return nil, err
  231. }
  232. if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
  233. return nil, err
  234. }
  235. mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
  236. if err != nil {
  237. return nil, err
  238. }
  239. mlkem768Key, encapsulatedPfsKey := mlkem768EKey.Encapsulate()
  240. peerX25519PKey, err := ecdh.X25519().NewPublicKey(encryptedPfsPublicKey[1184 : 1184+32])
  241. if err != nil {
  242. return nil, err
  243. }
  244. x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
  245. x25519Key, err := x25519SKey.ECDH(peerX25519PKey)
  246. if err != nil {
  247. return nil, err
  248. }
  249. pfsKey := make([]byte, 32+32) // no more capacity
  250. copy(pfsKey, mlkem768Key)
  251. copy(pfsKey[32:], x25519Key)
  252. pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
  253. c.UnitedKey = append(pfsKey, nfsKey...)
  254. c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
  255. c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
  256. ticket := [16]byte{}
  257. rand.Read(ticket[:])
  258. var seconds int64
  259. if i.SecondsTo == 0 {
  260. seconds = i.SecondsFrom * crypto.RandBetween(50, 100) / 100
  261. } else {
  262. seconds = crypto.RandBetween(i.SecondsFrom, i.SecondsTo)
  263. }
  264. copy(ticket[:], EncodeLength(int(seconds)))
  265. if seconds > 0 {
  266. i.RWLock.Lock()
  267. i.Lasts[(time.Now().Unix()+max(i.SecondsFrom, i.SecondsTo))/60+2] = ticket
  268. i.Tickets = append(i.Tickets, ticket)
  269. i.Sessions[ticket] = &ServerSession{PfsKey: pfsKey}
  270. i.RWLock.Unlock()
  271. }
  272. pfsKeyExchangeLength := 1088 + 32 + 16
  273. encryptedTicketLength := 32
  274. paddingLength, paddingLens, paddingGaps := CreatPadding(i.PaddingLens, i.PaddingGaps)
  275. serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
  276. nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
  277. c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket[:], nil)
  278. padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
  279. c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
  280. c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
  281. paddingLens[0] = pfsKeyExchangeLength + encryptedTicketLength + paddingLens[0]
  282. for i, l := range paddingLens { // sends padding in a fragmented way, to create variable traffic pattern, before inner VLESS flow takes control
  283. if l > 0 {
  284. if _, err := conn.Write(serverHello[:l]); err != nil {
  285. return nil, err
  286. }
  287. serverHello = serverHello[l:]
  288. }
  289. if len(paddingGaps) > i {
  290. time.Sleep(paddingGaps[i])
  291. }
  292. }
  293. // important: allows client sends padding slowly, eliminating 1-RTT's traffic pattern
  294. if _, err := io.ReadFull(conn, encryptedLength); err != nil {
  295. return nil, err
  296. }
  297. if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
  298. return nil, err
  299. }
  300. encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
  301. if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
  302. return nil, err
  303. }
  304. if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
  305. return nil, err
  306. }
  307. if i.XorMode == 2 {
  308. c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, ticket[:]), NewCTR(c.UnitedKey, iv), 0, 0)
  309. }
  310. return c, nil
  311. }