common.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package encryption
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "fmt"
  7. "io"
  8. "net"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/xtls/xray-core/common/errors"
  13. "golang.org/x/crypto/chacha20poly1305"
  14. "lukechampine.com/blake3"
  15. )
  16. var OutBytesPool = sync.Pool{
  17. New: func() any {
  18. return make([]byte, 5+8192+16)
  19. },
  20. }
// CommonConn wraps a net.Conn with record-framed AEAD encryption: Write seals
// outgoing data with AEAD and Read opens incoming records with PeerAEAD.
// Its methods mutate these fields without locking c.
type CommonConn struct {
	net.Conn
	// UseAES selects AES-GCM (true) or ChaCha20-Poly1305 (false) whenever a
	// new AEAD is derived via NewAEAD.
	UseAES bool
	// Client, when non-nil, is consulted by Read on an invalid record header
	// (the client's 0-RTT case) to expire its cached key; Read sets it to nil
	// after the first valid record header.
	Client *ClientInstance
	// UnitedKey is the key material passed to NewAEAD (and to NewCTR for an
	// underlying XorConn) when keys are derived or rotated.
	UnitedKey []byte
	// PreWrite, if non-nil, is sent by Write in front of the next record and
	// then cleared.
	PreWrite []byte
	// AEAD seals outgoing records.
	AEAD *AEAD
	// PeerAEAD opens incoming records. If nil, Read first reads a 16-byte
	// server random from the connection and derives PeerAEAD from it.
	PeerAEAD *AEAD
	// PeerPadding, if non-nil, is filled from the connection and
	// authenticated once by Read before any record is returned, then dropped.
	PeerPadding []byte
	// PeerInBytes is the receive buffer for one record; NewCommonConn sizes it
	// for a 5-byte header plus the 17000-byte maximum accepted by DecodeHeader.
	PeerInBytes []byte
	// PeerCache holds decrypted plaintext of the last record that did not fit
	// in the caller's buffer; it aliases PeerInBytes.
	PeerCache []byte
}
  33. func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
  34. return &CommonConn{
  35. Conn: conn,
  36. UseAES: useAES,
  37. PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
  38. }
  39. }
  40. func (c *CommonConn) Write(b []byte) (int, error) {
  41. if len(b) == 0 {
  42. return 0, nil
  43. }
  44. outBytes := OutBytesPool.Get().([]byte)
  45. defer OutBytesPool.Put(outBytes)
  46. for n := 0; n < len(b); {
  47. b := b[n:]
  48. if len(b) > 8192 {
  49. b = b[:8192] // for avoiding another copy() in peer's Read()
  50. }
  51. n += len(b)
  52. headerAndData := outBytes[:5+len(b)+16]
  53. EncodeHeader(headerAndData, len(b)+16)
  54. max := false
  55. if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
  56. max = true
  57. }
  58. c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
  59. if max {
  60. c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
  61. }
  62. if c.PreWrite != nil {
  63. headerAndData = append(c.PreWrite, headerAndData...)
  64. c.PreWrite = nil
  65. }
  66. if _, err := c.Conn.Write(headerAndData); err != nil {
  67. return 0, err
  68. }
  69. }
  70. return len(b), nil
  71. }
// Read returns decrypted plaintext from the peer's record stream.
//
// Before returning any record data, Read may perform one-time setup:
//   - if c.PeerAEAD is nil, it reads a 16-byte server random from the
//     connection and derives PeerAEAD from it (and, when the underlying conn
//     is an *XorConn, that conn's PeerCTR);
//   - if c.PeerPadding is non-nil, it reads len(PeerPadding) bytes,
//     authenticates them with PeerAEAD, and discards them.
//
// Plaintext left over from a previous record (c.PeerCache) is returned before
// any new record is read. Otherwise exactly one record is read: a 5-byte
// header followed by l bytes of ciphertext (17 <= l <= 17000, enforced by
// DecodeHeader) carrying l-16 bytes of plaintext. If that plaintext does not
// fit in b, the excess is kept in c.PeerCache for later calls.
//
// Read mutates c's fields without locking c, so it must not be called
// concurrently with itself.
func (c *CommonConn) Read(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	if c.PeerAEAD == nil { // client's 0-RTT
		// The peer's key is derived from a 16-byte random read from the
		// connection ahead of the first record.
		serverRandom := make([]byte, 16)
		if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
			return 0, err
		}
		c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
		// An XorConn underneath also gets its peer-direction CTR keyed from
		// the same random.
		if xorConn, ok := c.Conn.(*XorConn); ok {
			xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
		}
	}
	if c.PeerPadding != nil { // client's 1-RTT
		// Read and authenticate the padding block once (decrypted in place,
		// advancing PeerAEAD's nonce); its contents are then dropped.
		if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
			return 0, err
		}
		if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
			return 0, err
		}
		c.PeerPadding = nil
	}
	// Serve leftover plaintext from the previous record before reading a new
	// one; this also guarantees PeerInBytes is not overwritten while
	// PeerCache still aliases it.
	if len(c.PeerCache) > 0 {
		n := copy(b, c.PeerCache)
		c.PeerCache = c.PeerCache[n:]
		return n, nil
	}
	peerHeader := c.PeerInBytes[:5]
	if _, err := io.ReadFull(c.Conn, peerHeader); err != nil {
		return 0, err
	}
	l, err := DecodeHeader(c.PeerInBytes[:5]) // l: 17~17000
	if err != nil {
		// While c.Client is still set (no valid record seen yet), an invalid
		// header is treated as the 0-RTT attempt failing: if this connection's
		// key starts with the client's current PfsKey, mark that key expired
		// (presumably so the next connection performs a fresh handshake).
		if c.Client != nil && strings.Contains(err.Error(), "invalid header: ") { // client's 0-RTT
			c.Client.RWLock.Lock()
			if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) {
				c.Client.Expire = time.Now() // expired
			}
			c.Client.RWLock.Unlock()
			return 0, errors.New("new handshake needed")
		}
		return 0, err
	}
	// A valid header has been seen; later header errors are returned as-is
	// without touching the client's key state.
	c.Client = nil
	peerData := c.PeerInBytes[5 : 5+l]
	if _, err := io.ReadFull(c.Conn, peerData); err != nil {
		return 0, err
	}
	// Plaintext length is the ciphertext length minus the 16-byte tag. When
	// it fits, decrypt straight into b; otherwise decrypt in place inside
	// PeerInBytes (dst[:0] then starts exactly at peerData).
	dst := peerData[:l-16]
	if len(dst) <= len(b) {
		dst = b[:len(dst)] // avoids another copy()
	}
	// At MaxNonce this Open wraps the nonce counter. The next key is derived
	// from header+ciphertext, which must be captured before Open may
	// overwrite peerData in place; Write derives its new key from the same
	// sealed bytes.
	var newAEAD *AEAD
	if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
		newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES)
	}
	_, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader)
	// The key is rotated regardless of whether Open succeeded.
	if newAEAD != nil {
		c.PeerAEAD = newAEAD
	}
	if err != nil {
		return 0, err
	}
	// b was too small: return what fits and keep the rest for the next Read.
	if len(dst) > len(b) {
		c.PeerCache = dst[copy(b, dst):]
		dst = b // for len(dst)
	}
	return len(dst), nil
}
// AEAD is a cipher.AEAD paired with an implicit nonce counter. Its Seal and
// Open methods advance and use Nonce whenever they are called with a nil
// nonce.
type AEAD struct {
	cipher.AEAD
	// Nonce is the most recently used counter value, a 12-byte big-endian
	// integer advanced by IncreaseNonce; it starts at zero.
	Nonce [12]byte
}
  146. func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
  147. k := make([]byte, 32)
  148. blake3.DeriveKey(k, string(ctx), key)
  149. var aead cipher.AEAD
  150. if useAES {
  151. block, _ := aes.NewCipher(k)
  152. aead, _ = cipher.NewGCM(block)
  153. } else {
  154. aead, _ = chacha20poly1305.New(k)
  155. }
  156. return &AEAD{AEAD: aead}
  157. }
  158. func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
  159. if nonce == nil {
  160. nonce = IncreaseNonce(a.Nonce[:])
  161. }
  162. return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
  163. }
  164. func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
  165. if nonce == nil {
  166. nonce = IncreaseNonce(a.Nonce[:])
  167. }
  168. return a.AEAD.Open(dst, nonce, ciphertext, additionalData)
  169. }
  170. func IncreaseNonce(nonce []byte) []byte {
  171. for i := range 12 {
  172. nonce[11-i]++
  173. if nonce[11-i] != 0 {
  174. break
  175. }
  176. }
  177. return nonce
  178. }
  179. var MaxNonce = bytes.Repeat([]byte{255}, 12)
  180. func EncodeLength(l int) []byte {
  181. return []byte{byte(l >> 8), byte(l)}
  182. }
  183. func DecodeLength(b []byte) int {
  184. return int(b[0])<<8 | int(b[1])
  185. }
  186. func EncodeHeader(h []byte, l int) {
  187. h[0] = 23
  188. h[1] = 3
  189. h[2] = 3
  190. h[3] = byte(l >> 8)
  191. h[4] = byte(l)
  192. }
  193. func DecodeHeader(h []byte) (l int, err error) {
  194. l = int(h[3])<<8 | int(h[4])
  195. if h[0] != 23 || h[1] != 3 || h[2] != 3 {
  196. l = 0
  197. }
  198. if l < 17 || l > 17000 { // TODO: TLSv1.3 max length
  199. err = errors.New("invalid header: ", fmt.Sprintf("%v", h[:5])) // DO NOT CHANGE: relied by client's Read()
  200. }
  201. return
  202. }