protocol.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package shadowsocks
  2. import (
  3. "crypto/hmac"
  4. "crypto/rand"
  5. "crypto/sha256"
  6. "hash/crc32"
  7. "io"
  8. "io/ioutil"
  9. "github.com/xtls/xray-core/common"
  10. "github.com/xtls/xray-core/common/buf"
  11. "github.com/xtls/xray-core/common/dice"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/protocol"
  14. )
  // Version is the protocol version stored in the request headers built by
  // ReadTCPSession and DecodeUDPPacket.
  const Version = 1
  18. var addrParser = protocol.NewAddressParser(
  19. protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
  20. protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
  21. protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
  22. protocol.WithAddressTypeParser(func(b byte) byte {
  23. return b & 0x0F
  24. }),
  25. )
  26. // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
  27. func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
  28. account := user.Account.(*MemoryAccount)
  29. hashkdf := hmac.New(sha256.New, []byte("SSBSKDF"))
  30. hashkdf.Write(account.Key)
  31. behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))
  32. behaviorRand := dice.NewDeterministicDice(int64(behaviorSeed))
  33. BaseDrainSize := behaviorRand.Roll(3266)
  34. RandDrainMax := behaviorRand.Roll(64) + 1
  35. RandDrainRolled := dice.Roll(RandDrainMax)
  36. DrainSize := BaseDrainSize + 16 + 38 + RandDrainRolled
  37. readSizeRemain := DrainSize
  38. buffer := buf.New()
  39. defer buffer.Release()
  40. ivLen := account.Cipher.IVSize()
  41. var iv []byte
  42. if ivLen > 0 {
  43. if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
  44. readSizeRemain -= int(buffer.Len())
  45. DrainConnN(reader, readSizeRemain)
  46. return nil, nil, newError("failed to read IV").Base(err)
  47. }
  48. iv = append([]byte(nil), buffer.BytesTo(ivLen)...)
  49. }
  50. r, err := account.Cipher.NewDecryptionReader(account.Key, iv, reader)
  51. if err != nil {
  52. readSizeRemain -= int(buffer.Len())
  53. DrainConnN(reader, readSizeRemain)
  54. return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
  55. }
  56. br := &buf.BufferedReader{Reader: r}
  57. request := &protocol.RequestHeader{
  58. Version: Version,
  59. User: user,
  60. Command: protocol.RequestCommandTCP,
  61. }
  62. readSizeRemain -= int(buffer.Len())
  63. buffer.Clear()
  64. addr, port, err := addrParser.ReadAddressPort(buffer, br)
  65. if err != nil {
  66. readSizeRemain -= int(buffer.Len())
  67. DrainConnN(reader, readSizeRemain)
  68. return nil, nil, newError("failed to read address").Base(err)
  69. }
  70. request.Address = addr
  71. request.Port = port
  72. if request.Address == nil {
  73. readSizeRemain -= int(buffer.Len())
  74. DrainConnN(reader, readSizeRemain)
  75. return nil, nil, newError("invalid remote address.")
  76. }
  77. return request, br, nil
  78. }
  // DrainConnN reads and discards up to n bytes from reader.
  //
  // It returns nil when exactly n bytes were discarded, or at once when n <= 0.
  // It returns io.EOF if reader ends first, and any other read error unchanged.
  func DrainConnN(reader io.Reader, n int) error {
  	if _, err := io.CopyN(ioutil.Discard, reader, int64(n)); err != nil {
  		return err
  	}
  	return nil
  }
  83. // WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body.
  84. func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
  85. user := request.User
  86. account := user.Account.(*MemoryAccount)
  87. var iv []byte
  88. if account.Cipher.IVSize() > 0 {
  89. iv = make([]byte, account.Cipher.IVSize())
  90. common.Must2(rand.Read(iv))
  91. if err := buf.WriteAllBytes(writer, iv); err != nil {
  92. return nil, newError("failed to write IV")
  93. }
  94. }
  95. w, err := account.Cipher.NewEncryptionWriter(account.Key, iv, writer)
  96. if err != nil {
  97. return nil, newError("failed to create encoding stream").Base(err).AtError()
  98. }
  99. header := buf.New()
  100. if err := addrParser.WriteAddressPort(header, request.Address, request.Port); err != nil {
  101. return nil, newError("failed to write address").Base(err)
  102. }
  103. if err := w.WriteMultiBuffer(buf.MultiBuffer{header}); err != nil {
  104. return nil, newError("failed to write header").Base(err)
  105. }
  106. return w, nil
  107. }
  108. func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) {
  109. account := user.Account.(*MemoryAccount)
  110. var iv []byte
  111. if account.Cipher.IVSize() > 0 {
  112. iv = make([]byte, account.Cipher.IVSize())
  113. if _, err := io.ReadFull(reader, iv); err != nil {
  114. return nil, newError("failed to read IV").Base(err)
  115. }
  116. }
  117. return account.Cipher.NewDecryptionReader(account.Key, iv, reader)
  118. }
  119. func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
  120. user := request.User
  121. account := user.Account.(*MemoryAccount)
  122. var iv []byte
  123. if account.Cipher.IVSize() > 0 {
  124. iv = make([]byte, account.Cipher.IVSize())
  125. common.Must2(rand.Read(iv))
  126. if err := buf.WriteAllBytes(writer, iv); err != nil {
  127. return nil, newError("failed to write IV.").Base(err)
  128. }
  129. }
  130. return account.Cipher.NewEncryptionWriter(account.Key, iv, writer)
  131. }
  132. func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
  133. user := request.User
  134. account := user.Account.(*MemoryAccount)
  135. buffer := buf.New()
  136. ivLen := account.Cipher.IVSize()
  137. if ivLen > 0 {
  138. common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen))
  139. }
  140. if err := addrParser.WriteAddressPort(buffer, request.Address, request.Port); err != nil {
  141. return nil, newError("failed to write address").Base(err)
  142. }
  143. buffer.Write(payload)
  144. if err := account.Cipher.EncodePacket(account.Key, buffer); err != nil {
  145. return nil, newError("failed to encrypt UDP payload").Base(err)
  146. }
  147. return buffer, nil
  148. }
  149. func DecodeUDPPacket(user *protocol.MemoryUser, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
  150. account := user.Account.(*MemoryAccount)
  151. var iv []byte
  152. if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
  153. // Keep track of IV as it gets removed from payload in DecodePacket.
  154. iv = make([]byte, account.Cipher.IVSize())
  155. copy(iv, payload.BytesTo(account.Cipher.IVSize()))
  156. }
  157. if err := account.Cipher.DecodePacket(account.Key, payload); err != nil {
  158. return nil, nil, newError("failed to decrypt UDP payload").Base(err)
  159. }
  160. request := &protocol.RequestHeader{
  161. Version: Version,
  162. User: user,
  163. Command: protocol.RequestCommandUDP,
  164. }
  165. payload.SetByte(0, payload.Byte(0)&0x0F)
  166. addr, port, err := addrParser.ReadAddressPort(nil, payload)
  167. if err != nil {
  168. return nil, nil, newError("failed to parse address").Base(err)
  169. }
  170. request.Address = addr
  171. request.Port = port
  172. return request, payload, nil
  173. }
  174. type UDPReader struct {
  175. Reader io.Reader
  176. User *protocol.MemoryUser
  177. }
  178. func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  179. buffer := buf.New()
  180. _, err := buffer.ReadFrom(v.Reader)
  181. if err != nil {
  182. buffer.Release()
  183. return nil, err
  184. }
  185. u, payload, err := DecodeUDPPacket(v.User, buffer)
  186. if err != nil {
  187. buffer.Release()
  188. return nil, err
  189. }
  190. payload.UDP = &net.UDPAddr{
  191. IP: u.Address.IP(),
  192. Port: int(u.Port),
  193. }
  194. return buf.MultiBuffer{payload}, nil
  195. }
  196. type UDPWriter struct {
  197. Writer io.Writer
  198. Request *protocol.RequestHeader
  199. }
  200. func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
  201. for {
  202. mb2, b := buf.SplitFirst(mb)
  203. mb = mb2
  204. if b == nil {
  205. break
  206. }
  207. var packet *buf.Buffer
  208. var err error
  209. if b.UDP != nil {
  210. request := &protocol.RequestHeader{
  211. User: w.Request.User,
  212. Address: net.IPAddress(b.UDP.IP),
  213. Port: net.Port(b.UDP.Port),
  214. }
  215. packet, err = EncodeUDPPacket(request, b.Bytes())
  216. } else {
  217. packet, err = EncodeUDPPacket(w.Request, b.Bytes())
  218. }
  219. b.Release()
  220. if err != nil {
  221. buf.ReleaseMulti(mb)
  222. return err
  223. }
  224. _, err = w.Writer.Write(packet.Bytes())
  225. packet.Release()
  226. if err != nil {
  227. buf.ReleaseMulti(mb)
  228. return err
  229. }
  230. }
  231. return nil
  232. }