conn.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package tf
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "math/rand"
  7. "net"
  8. "strings"
  9. "time"
  10. C "github.com/sagernet/sing-box/constant"
  11. N "github.com/sagernet/sing/common/network"
  12. "golang.org/x/net/publicsuffix"
  13. )
  14. type Conn struct {
  15. net.Conn
  16. tcpConn *net.TCPConn
  17. ctx context.Context
  18. firstPacketWritten bool
  19. splitPacket bool
  20. splitRecord bool
  21. fallbackDelay time.Duration
  22. }
  23. func NewConn(conn net.Conn, ctx context.Context, splitPacket bool, splitRecord bool, fallbackDelay time.Duration) *Conn {
  24. if fallbackDelay == 0 {
  25. fallbackDelay = C.TLSFragmentFallbackDelay
  26. }
  27. tcpConn, _ := N.UnwrapReader(conn).(*net.TCPConn)
  28. return &Conn{
  29. Conn: conn,
  30. tcpConn: tcpConn,
  31. ctx: ctx,
  32. splitPacket: splitPacket,
  33. splitRecord: splitRecord,
  34. fallbackDelay: fallbackDelay,
  35. }
  36. }
  37. func (c *Conn) Write(b []byte) (n int, err error) {
  38. if !c.firstPacketWritten {
  39. defer func() {
  40. c.firstPacketWritten = true
  41. }()
  42. serverName := IndexTLSServerName(b)
  43. if serverName != nil {
  44. if c.splitPacket {
  45. if c.tcpConn != nil {
  46. err = c.tcpConn.SetNoDelay(true)
  47. if err != nil {
  48. return
  49. }
  50. }
  51. }
  52. splits := strings.Split(serverName.ServerName, ".")
  53. currentIndex := serverName.Index
  54. if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
  55. splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")]
  56. }
  57. if len(splits) > 1 && splits[0] == "..." {
  58. currentIndex += len(splits[0]) + 1
  59. splits = splits[1:]
  60. }
  61. var splitIndexes []int
  62. for i, split := range splits {
  63. splitAt := rand.Intn(len(split))
  64. splitIndexes = append(splitIndexes, currentIndex+splitAt)
  65. currentIndex += len(split)
  66. if i != len(splits)-1 {
  67. currentIndex++
  68. }
  69. }
  70. var buffer bytes.Buffer
  71. for i := 0; i <= len(splitIndexes); i++ {
  72. var payload []byte
  73. if i == 0 {
  74. payload = b[:splitIndexes[i]]
  75. if c.splitRecord {
  76. payload = payload[recordLayerHeaderLen:]
  77. }
  78. } else if i == len(splitIndexes) {
  79. payload = b[splitIndexes[i-1]:]
  80. } else {
  81. payload = b[splitIndexes[i-1]:splitIndexes[i]]
  82. }
  83. if c.splitRecord {
  84. if c.splitPacket {
  85. buffer.Reset()
  86. }
  87. payloadLen := uint16(len(payload))
  88. buffer.Write(b[:3])
  89. binary.Write(&buffer, binary.BigEndian, payloadLen)
  90. buffer.Write(payload)
  91. if c.splitPacket {
  92. payload = buffer.Bytes()
  93. }
  94. }
  95. if c.splitPacket {
  96. if c.tcpConn != nil && i != len(splitIndexes) {
  97. err = writeAndWaitAck(c.ctx, c.tcpConn, payload, c.fallbackDelay)
  98. if err != nil {
  99. return
  100. }
  101. } else {
  102. _, err = c.Conn.Write(payload)
  103. if err != nil {
  104. return
  105. }
  106. if i != len(splitIndexes) {
  107. time.Sleep(c.fallbackDelay)
  108. }
  109. }
  110. }
  111. }
  112. if c.splitRecord && !c.splitPacket {
  113. _, err = c.Conn.Write(buffer.Bytes())
  114. if err != nil {
  115. return
  116. }
  117. }
  118. if c.tcpConn != nil {
  119. err = c.tcpConn.SetNoDelay(false)
  120. if err != nil {
  121. return
  122. }
  123. }
  124. return len(b), nil
  125. }
  126. }
  127. return c.Conn.Write(b)
  128. }
  129. func (c *Conn) ReaderReplaceable() bool {
  130. return true
  131. }
  132. func (c *Conn) WriterReplaceable() bool {
  133. return c.firstPacketWritten
  134. }
  135. func (c *Conn) Upstream() any {
  136. return c.Conn
  137. }