common.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package encryption
  2. import (
  3. "bytes"
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "fmt"
  7. "io"
  8. "net"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. "github.com/xtls/xray-core/common/crypto"
  14. "github.com/xtls/xray-core/common/errors"
  15. "golang.org/x/crypto/chacha20poly1305"
  16. "lukechampine.com/blake3"
  17. )
  // OutBytesPool recycles scratch buffers used by CommonConn.Write. Each buffer
  // fits one outgoing record at maximum size: a 5-byte header, up to 8192 bytes
  // of payload and a 16-byte AEAD tag.
  var OutBytesPool = sync.Pool{
  	New: func() any {
  		const header, payload, tag = 5, 8192, 16
  		buf := make([]byte, header+payload+tag)
  		return buf
  	},
  }
  23. type CommonConn struct {
  24. net.Conn
  25. UseAES bool
  26. Client *ClientInstance
  27. UnitedKey []byte
  28. PreWrite []byte
  29. AEAD *AEAD
  30. PeerAEAD *AEAD
  31. PeerPadding []byte
  32. rawInput bytes.Buffer
  33. input bytes.Reader
  34. }
  35. func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
  36. return &CommonConn{
  37. Conn: conn,
  38. UseAES: useAES,
  39. }
  40. }
  41. func (c *CommonConn) Write(b []byte) (int, error) {
  42. if len(b) == 0 {
  43. return 0, nil
  44. }
  45. outBytes := OutBytesPool.Get().([]byte)
  46. defer OutBytesPool.Put(outBytes)
  47. for n := 0; n < len(b); {
  48. b := b[n:]
  49. if len(b) > 8192 {
  50. b = b[:8192] // for avoiding another copy() in peer's Read()
  51. }
  52. n += len(b)
  53. headerAndData := outBytes[:5+len(b)+16]
  54. EncodeHeader(headerAndData, len(b)+16)
  55. max := false
  56. if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
  57. max = true
  58. }
  59. c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
  60. if max {
  61. c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
  62. }
  63. if c.PreWrite != nil {
  64. headerAndData = append(c.PreWrite, headerAndData...)
  65. c.PreWrite = nil
  66. }
  67. if _, err := c.Conn.Write(headerAndData); err != nil {
  68. return 0, err
  69. }
  70. }
  71. return len(b), nil
  72. }
  // Read returns decrypted application data from the underlying connection.
  //
  // Before the first record it may consume handshake leftovers. On the client's
  // 0-RTT path (PeerAEAD == nil) it reads a 16-byte server random, derives
  // PeerAEAD from it and, if the underlying conn is an *XorConn, sets its PeerCTR.
  // On the client's 1-RTT path (PeerPadding != nil) it reads the peer's padding
  // and authenticates it in place, discarding the plaintext.
  //
  // Each record is a 5-byte header (validated by DecodeHeader) followed by l
  // bytes of ciphertext including a 16-byte tag; the header is the additional
  // data. Plaintext that does not fit in b is kept in c.input and returned by
  // subsequent calls before any new record is read.
  func (c *CommonConn) Read(b []byte) (int, error) {
  	if len(b) == 0 {
  		return 0, nil
  	}
  	if c.PeerAEAD == nil { // client's 0-RTT
  		serverRandom := make([]byte, 16)
  		if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
  			return 0, err
  		}
  		c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
  		if xorConn, ok := c.Conn.(*XorConn); ok {
  			xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
  		}
  	}
  	if c.PeerPadding != nil { // client's 1-RTT
  		if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
  			return 0, err
  		}
  		// Decrypt in place purely to authenticate the padding; content is unused.
  		if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
  			return 0, err
  		}
  		c.PeerPadding = nil
  	}
  	// Drain plaintext left over from a previous record before reading a new one.
  	if c.input.Len() > 0 {
  		return c.input.Read(b)
  	}
  	peerHeader := [5]byte{}
  	if _, err := io.ReadFull(c.Conn, peerHeader[:]); err != nil {
  		return 0, err
  	}
  	l, err := DecodeHeader(peerHeader[:]) // l: 17~17000
  	if err != nil {
  		// On a client that has not yet seen a valid record, a malformed header is
  		// taken to mean the 0-RTT key was rejected: if this conn's UnitedKey starts
  		// with the client's current PfsKey, mark that key expired and report that
  		// a new handshake is needed. Matches the message text built by DecodeHeader.
  		if c.Client != nil && strings.Contains(err.Error(), "invalid header: ") { // client's 0-RTT
  			c.Client.RWLock.Lock()
  			if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) {
  				c.Client.Expire = time.Now() // expired
  			}
  			c.Client.RWLock.Unlock()
  			return 0, errors.New("new handshake needed")
  		}
  		return 0, err
  	}
  	// A valid header was seen; the 0-RTT fallback above no longer applies.
  	c.Client = nil
  	if c.rawInput.Cap() < l {
  		c.rawInput.Grow(l) // no need to use sync.Pool, because we are always reading
  	}
  	// rawInput is never written to here, so it serves only as reusable backing
  	// storage: peerData slices into its spare capacity.
  	peerData := c.rawInput.Bytes()[:l]
  	if _, err := io.ReadFull(c.Conn, peerData); err != nil {
  		return 0, err
  	}
  	// Plaintext length is the ciphertext length minus the 16-byte tag.
  	dst := peerData[:l-16]
  	if len(dst) <= len(b) {
  		dst = b[:len(dst)] // avoids another copy()
  	}
  	// The rekey context (header+ciphertext) must be captured before Open, because
  	// when dst aliases peerData the ciphertext is overwritten by the plaintext.
  	var newAEAD *AEAD
  	if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
  		newAEAD = NewAEAD(append(peerHeader[:], peerData...), c.UnitedKey, c.UseAES)
  	}
  	_, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader[:])
  	if newAEAD != nil {
  		c.PeerAEAD = newAEAD
  	}
  	if err != nil {
  		return 0, err
  	}
  	if len(dst) > len(b) {
  		// Decrypted in place inside rawInput: return what fits, keep the rest.
  		c.input.Reset(dst[copy(b, dst):])
  		dst = b // for len(dst)
  	}
  	return len(dst), nil
  }
  // AEAD pairs a cipher.AEAD with an implicit 96-bit counter nonce. When Seal or
  // Open is given a nil nonce, Nonce is incremented in place and that value is used.
  type AEAD struct {
  	cipher.AEAD
  	Nonce [12]byte // big-endian counter advanced by IncreaseNonce
  }
  148. func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
  149. k := make([]byte, 32)
  150. blake3.DeriveKey(k, string(ctx), key)
  151. var aead cipher.AEAD
  152. if useAES {
  153. block, _ := aes.NewCipher(k)
  154. aead, _ = cipher.NewGCM(block)
  155. } else {
  156. aead, _ = chacha20poly1305.New(k)
  157. }
  158. return &AEAD{AEAD: aead}
  159. }
  160. func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
  161. if nonce == nil {
  162. nonce = IncreaseNonce(a.Nonce[:])
  163. }
  164. return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
  165. }
  166. func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
  167. if nonce == nil {
  168. nonce = IncreaseNonce(a.Nonce[:])
  169. }
  170. return a.AEAD.Open(dst, nonce, ciphertext, additionalData)
  171. }
  // IncreaseNonce increments nonce[0:12] in place as a big-endian counter,
  // wrapping all-0xff to all-zero, and returns nonce. It panics if len(nonce) < 12.
  func IncreaseNonce(nonce []byte) []byte {
  	for idx := 11; idx >= 0; idx-- {
  		nonce[idx]++
  		if nonce[idx] != 0 {
  			break // no carry into the next byte
  		}
  	}
  	return nonce
  }
  // MaxNonce is the all-ones 12-byte nonce. Write and Read compare the current
  // counter against it to decide when to rekey.
  var MaxNonce = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
  // EncodeLength returns the low 16 bits of l as two big-endian bytes; higher
  // bits are silently dropped.
  func EncodeLength(l int) []byte {
  	out := make([]byte, 2)
  	out[0], out[1] = byte(l>>8), byte(l)
  	return out
  }
  // DecodeLength reads a big-endian 16-bit length from b[0:2]; any further bytes
  // are ignored.
  func DecodeLength(b []byte) int {
  	hi, lo := int(b[0]), int(b[1])
  	return hi<<8 + lo
  }
  // EncodeHeader writes a 5-byte record header into h[0:5]. The bytes match a
  // TLS record header (content type 23 = application_data, version 0x0303),
  // followed by the low 16 bits of l in big-endian order.
  func EncodeHeader(h []byte, l int) {
  	h[0] = 23
  	h[1], h[2] = 3, 3
  	h[3], h[4] = byte(l>>8), byte(l)
  }
  // DecodeHeader validates a record header in h[0:5] and returns the payload
  // length it announces. The prefix must be 23,3,3 (otherwise l is reported as 0)
  // and the length must lie in [17, 17000]; anything else yields an
  // "invalid header: " error, alongside the offending l.
  func DecodeHeader(h []byte) (l int, err error) {
  	isRecord := h[0] == 23 && h[1] == 3 && h[2] == 3
  	if isRecord {
  		l = int(h[3])<<8 | int(h[4])
  	}
  	if l >= 17 && l <= 17000 { // TODO: TLSv1.3 max length
  		return l, nil
  	}
  	return l, errors.New("invalid header: " + fmt.Sprintf("%v", h[:5])) // DO NOT CHANGE: relied by client's Read()
  }
  // ParsePadding parses a '.'-separated list of "probability-min-max" integer
  // triples. Entries at even positions (0, 2, ...) describe padding lengths and
  // are appended to *paddingLens. Entries at odd positions describe gaps and are
  // appended to *paddingGaps.
  //
  // Constraints:
  //   - every entry has exactly three non-empty integer fields;
  //   - the first entry has probability >= 100 and both min and max >= 35;
  //   - the sum over length entries of max(min, max) is at most 65553.
  //
  // An empty string is accepted and leaves both slices untouched. On any error
  // neither output slice is modified.
  func ParsePadding(padding string, paddingLens, paddingGaps *[][3]int) (err error) {
  	if padding == "" {
  		return nil
  	}
  	const maxTotal = 18 + 65535
  	var lens, gaps [][3]int
  	total := 0
  	for i, s := range strings.Split(padding, ".") {
  		x := strings.Split(s, "-")
  		// Exactly three fields: extra fields used to be silently ignored.
  		if len(x) != 3 || x[0] == "" || x[1] == "" || x[2] == "" {
  			return errors.New("invalid padding length/gap parameter: " + s)
  		}
  		var y [3]int
  		for j := range y {
  			if y[j], err = strconv.Atoi(x[j]); err != nil {
  				return err
  			}
  		}
  		if i == 0 && (y[0] < 100 || y[1] < 18+17 || y[2] < 18+17) {
  			return errors.New("first padding must have probability 100 and length not smaller than 35")
  		}
  		if i%2 == 0 {
  			l := max(y[1], y[2])
  			// Compare against the remaining budget instead of summing first, so
  			// huge values cannot overflow int and slip past the cap.
  			if l > maxTotal-total {
  				return errors.New("total padding length must not be larger than 65553")
  			}
  			total += l
  			lens = append(lens, y)
  		} else {
  			gaps = append(gaps, y)
  		}
  	}
  	// Commit only once the whole specification has been validated.
  	*paddingLens = append(*paddingLens, lens...)
  	*paddingGaps = append(*paddingGaps, gaps...)
  	return nil
  }
  240. func CreatPadding(paddingLens, paddingGaps [][3]int) (length int, lens []int, gaps []time.Duration) {
  241. if len(paddingLens) == 0 {
  242. paddingLens = [][3]int{{100, 111, 1111}, {50, 0, 3333}}
  243. paddingGaps = [][3]int{{75, 0, 111}}
  244. }
  245. for _, y := range paddingLens {
  246. l := 0
  247. if y[0] >= int(crypto.RandBetween(0, 100)) {
  248. l = int(crypto.RandBetween(int64(y[1]), int64(y[2])))
  249. }
  250. lens = append(lens, l)
  251. length += l
  252. }
  253. for _, y := range paddingGaps {
  254. g := 0
  255. if y[0] >= int(crypto.RandBetween(0, 100)) {
  256. g = int(crypto.RandBetween(int64(y[1]), int64(y[2])))
  257. }
  258. gaps = append(gaps, time.Duration(g)*time.Millisecond)
  259. }
  260. return
  261. }