xplus.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package hysteria
  2. import (
  3. "crypto/sha256"
  4. "math/rand"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/sagernet/sing/common"
  9. "github.com/sagernet/sing/common/buf"
  10. "github.com/sagernet/sing/common/bufio"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. )
  15. const xplusSaltLen = 16
  16. var errInalidPacket = E.New("invalid packet")
  17. func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn {
  18. vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
  19. if isVectorised {
  20. return &VectorisedXPlusConn{
  21. XPlusPacketConn: XPlusPacketConn{
  22. PacketConn: conn,
  23. key: key,
  24. rand: rand.New(rand.NewSource(time.Now().UnixNano())),
  25. },
  26. writer: vectorisedWriter,
  27. }
  28. } else {
  29. return &XPlusPacketConn{
  30. PacketConn: conn,
  31. key: key,
  32. rand: rand.New(rand.NewSource(time.Now().UnixNano())),
  33. }
  34. }
  35. }
  36. type XPlusPacketConn struct {
  37. net.PacketConn
  38. key []byte
  39. randAccess sync.Mutex
  40. rand *rand.Rand
  41. }
  42. func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  43. n, addr, err = c.PacketConn.ReadFrom(p)
  44. if err != nil {
  45. return
  46. } else if n < xplusSaltLen {
  47. return 0, nil, errInalidPacket
  48. }
  49. key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...))
  50. for i := range p[xplusSaltLen:] {
  51. p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size]
  52. }
  53. n -= xplusSaltLen
  54. return
  55. }
  56. func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  57. // can't use unsafe buffer on WriteTo
  58. buffer := buf.NewSize(len(p) + xplusSaltLen)
  59. defer buffer.Release()
  60. salt := buffer.Extend(xplusSaltLen)
  61. c.randAccess.Lock()
  62. _, _ = c.rand.Read(salt)
  63. c.randAccess.Unlock()
  64. key := sha256.Sum256(append(c.key, salt...))
  65. for i := range p {
  66. common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size]))
  67. }
  68. return c.PacketConn.WriteTo(buffer.Bytes(), addr)
  69. }
  70. func (c *XPlusPacketConn) Upstream() any {
  71. return c.PacketConn
  72. }
  73. type VectorisedXPlusConn struct {
  74. XPlusPacketConn
  75. writer N.VectorisedPacketWriter
  76. }
  77. func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
  78. header := buf.NewSize(xplusSaltLen)
  79. defer header.Release()
  80. salt := header.Extend(xplusSaltLen)
  81. c.randAccess.Lock()
  82. _, _ = c.rand.Read(salt)
  83. c.randAccess.Unlock()
  84. key := sha256.Sum256(append(c.key, salt...))
  85. for i := range p {
  86. p[i] ^= key[i%sha256.Size]
  87. }
  88. return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr))
  89. }
  90. func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error {
  91. header := buf.NewSize(xplusSaltLen)
  92. salt := header.Extend(xplusSaltLen)
  93. c.randAccess.Lock()
  94. _, _ = c.rand.Read(salt)
  95. c.randAccess.Unlock()
  96. key := sha256.Sum256(append(c.key, salt...))
  97. var index int
  98. for _, buffer := range buffers {
  99. data := buffer.Bytes()
  100. for i := range data {
  101. data[i] ^= key[index%sha256.Size]
  102. index++
  103. }
  104. }
  105. buffers = append([]*buf.Buffer{header}, buffers...)
  106. return c.writer.WriteVectorisedPacket(buffers, destination)
  107. }