handle.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "net"
  6. "github.com/sagernet/sing-box/adapter"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing-dns"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. "github.com/sagernet/sing/common/canceler"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/task"
  16. mDNS "github.com/miekg/dns"
  17. )
  18. func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error {
  19. var queryLength uint16
  20. err := binary.Read(conn, binary.BigEndian, &queryLength)
  21. if err != nil {
  22. return err
  23. }
  24. if queryLength == 0 {
  25. return dns.RCodeFormatError
  26. }
  27. buffer := buf.NewSize(int(queryLength))
  28. defer buffer.Release()
  29. _, err = buffer.ReadFullFrom(conn, int(queryLength))
  30. if err != nil {
  31. return err
  32. }
  33. var message mDNS.Msg
  34. err = message.Unpack(buffer.Bytes())
  35. if err != nil {
  36. return err
  37. }
  38. metadataInQuery := metadata
  39. go func() error {
  40. response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
  41. if err != nil {
  42. return err
  43. }
  44. responseBuffer := buf.NewPacket()
  45. defer responseBuffer.Release()
  46. responseBuffer.Resize(2, 0)
  47. n, err := response.PackBuffer(responseBuffer.FreeBytes())
  48. if err != nil {
  49. return err
  50. }
  51. responseBuffer.Truncate(len(n))
  52. binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
  53. _, err = conn.Write(responseBuffer.Bytes())
  54. return err
  55. }()
  56. return nil
  57. }
  58. func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
  59. metadata.Destination = M.Socksaddr{}
  60. var reader N.PacketReader = conn
  61. var counters []N.CountFunc
  62. cachedPackets = common.Reverse(cachedPackets)
  63. for {
  64. reader, counters = N.UnwrapCountPacketReader(reader, counters)
  65. if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
  66. packet := cachedReader.ReadCachedPacket()
  67. if packet != nil {
  68. cachedPackets = append(cachedPackets, packet)
  69. continue
  70. }
  71. }
  72. if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
  73. readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
  74. return newDNSPacketConnection(ctx, router, conn, readWaiter, counters, cachedPackets, metadata)
  75. }
  76. break
  77. }
  78. fastClose, cancel := common.ContextWithCancelCause(ctx)
  79. timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
  80. var group task.Group
  81. group.Append0(func(_ context.Context) error {
  82. for {
  83. var message mDNS.Msg
  84. var destination M.Socksaddr
  85. var err error
  86. if len(cachedPackets) > 0 {
  87. packet := cachedPackets[0]
  88. cachedPackets = cachedPackets[1:]
  89. for _, counter := range counters {
  90. counter(int64(packet.Buffer.Len()))
  91. }
  92. err = message.Unpack(packet.Buffer.Bytes())
  93. packet.Buffer.Release()
  94. if err != nil {
  95. cancel(err)
  96. return err
  97. }
  98. destination = packet.Destination
  99. } else {
  100. buffer := buf.NewPacket()
  101. destination, err = conn.ReadPacket(buffer)
  102. if err != nil {
  103. buffer.Release()
  104. cancel(err)
  105. return err
  106. }
  107. for _, counter := range counters {
  108. counter(int64(buffer.Len()))
  109. }
  110. err = message.Unpack(buffer.Bytes())
  111. buffer.Release()
  112. if err != nil {
  113. cancel(err)
  114. return err
  115. }
  116. timeout.Update()
  117. }
  118. metadataInQuery := metadata
  119. go func() error {
  120. response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
  121. if err != nil {
  122. cancel(err)
  123. return err
  124. }
  125. timeout.Update()
  126. responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
  127. if err != nil {
  128. cancel(err)
  129. return err
  130. }
  131. err = conn.WritePacket(responseBuffer, destination)
  132. if err != nil {
  133. cancel(err)
  134. }
  135. return err
  136. }()
  137. }
  138. })
  139. group.Cleanup(func() {
  140. conn.Close()
  141. })
  142. return group.Run(fastClose)
  143. }
  144. func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
  145. fastClose, cancel := common.ContextWithCancelCause(ctx)
  146. timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
  147. var group task.Group
  148. group.Append0(func(_ context.Context) error {
  149. for {
  150. var (
  151. message mDNS.Msg
  152. destination M.Socksaddr
  153. err error
  154. buffer *buf.Buffer
  155. )
  156. if len(cached) > 0 {
  157. packet := cached[0]
  158. cached = cached[1:]
  159. for _, counter := range readCounters {
  160. counter(int64(packet.Buffer.Len()))
  161. }
  162. err = message.Unpack(packet.Buffer.Bytes())
  163. packet.Buffer.Release()
  164. destination = packet.Destination
  165. N.PutPacketBuffer(packet)
  166. if err != nil {
  167. cancel(err)
  168. return err
  169. }
  170. } else {
  171. buffer, destination, err = readWaiter.WaitReadPacket()
  172. if err != nil {
  173. cancel(err)
  174. return err
  175. }
  176. for _, counter := range readCounters {
  177. counter(int64(buffer.Len()))
  178. }
  179. err = message.Unpack(buffer.Bytes())
  180. buffer.Release()
  181. if err != nil {
  182. cancel(err)
  183. return err
  184. }
  185. timeout.Update()
  186. }
  187. metadataInQuery := metadata
  188. go func() error {
  189. response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
  190. if err != nil {
  191. cancel(err)
  192. return err
  193. }
  194. timeout.Update()
  195. responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
  196. if err != nil {
  197. cancel(err)
  198. return err
  199. }
  200. err = conn.WritePacket(responseBuffer, destination)
  201. if err != nil {
  202. cancel(err)
  203. }
  204. return err
  205. }()
  206. }
  207. })
  208. group.Cleanup(func() {
  209. conn.Close()
  210. })
  211. return group.Run(fastClose)
  212. }