handle.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package dns
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "net"
  6. "github.com/sagernet/sing-box/adapter"
  7. C "github.com/sagernet/sing-box/constant"
  8. "github.com/sagernet/sing-box/dns"
  9. "github.com/sagernet/sing/common"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/bufio"
  12. "github.com/sagernet/sing/common/canceler"
  13. M "github.com/sagernet/sing/common/metadata"
  14. N "github.com/sagernet/sing/common/network"
  15. "github.com/sagernet/sing/common/task"
  16. mDNS "github.com/miekg/dns"
  17. )
  18. func HandleStreamDNSRequest(ctx context.Context, router adapter.DNSRouter, conn net.Conn, metadata adapter.InboundContext) error {
  19. var queryLength uint16
  20. err := binary.Read(conn, binary.BigEndian, &queryLength)
  21. if err != nil {
  22. return err
  23. }
  24. if queryLength == 0 {
  25. return dns.RcodeFormatError
  26. }
  27. buffer := buf.NewSize(int(queryLength))
  28. defer buffer.Release()
  29. _, err = buffer.ReadFullFrom(conn, int(queryLength))
  30. if err != nil {
  31. return err
  32. }
  33. var message mDNS.Msg
  34. err = message.Unpack(buffer.Bytes())
  35. if err != nil {
  36. return err
  37. }
  38. metadataInQuery := metadata
  39. go func() error {
  40. response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
  41. if err != nil {
  42. conn.Close()
  43. return err
  44. }
  45. responseLength := response.Len()
  46. responseBuffer := buf.NewSize(3 + responseLength)
  47. defer responseBuffer.Release()
  48. responseBuffer.Resize(2, 0)
  49. n, err := response.PackBuffer(responseBuffer.FreeBytes())
  50. if err != nil {
  51. return err
  52. }
  53. responseBuffer.Truncate(len(n))
  54. binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
  55. _, err = conn.Write(responseBuffer.Bytes())
  56. return err
  57. }()
  58. return nil
  59. }
  60. func NewDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
  61. metadata.Destination = M.Socksaddr{}
  62. var reader N.PacketReader = conn
  63. var counters []N.CountFunc
  64. cachedPackets = common.Reverse(cachedPackets)
  65. for {
  66. reader, counters = N.UnwrapCountPacketReader(reader, counters)
  67. if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
  68. packet := cachedReader.ReadCachedPacket()
  69. if packet != nil {
  70. cachedPackets = append(cachedPackets, packet)
  71. continue
  72. }
  73. }
  74. if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
  75. readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
  76. return newDNSPacketConnection(ctx, router, conn, readWaiter, counters, cachedPackets, metadata)
  77. }
  78. break
  79. }
  80. fastClose, cancel := common.ContextWithCancelCause(ctx)
  81. timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
  82. var group task.Group
  83. group.Append0(func(_ context.Context) error {
  84. for {
  85. var message mDNS.Msg
  86. var destination M.Socksaddr
  87. var err error
  88. if len(cachedPackets) > 0 {
  89. packet := cachedPackets[0]
  90. cachedPackets = cachedPackets[1:]
  91. for _, counter := range counters {
  92. counter(int64(packet.Buffer.Len()))
  93. }
  94. err = message.Unpack(packet.Buffer.Bytes())
  95. packet.Buffer.Release()
  96. if err != nil {
  97. cancel(err)
  98. return err
  99. }
  100. destination = packet.Destination
  101. } else {
  102. buffer := buf.NewPacket()
  103. destination, err = conn.ReadPacket(buffer)
  104. if err != nil {
  105. buffer.Release()
  106. cancel(err)
  107. return err
  108. }
  109. for _, counter := range counters {
  110. counter(int64(buffer.Len()))
  111. }
  112. err = message.Unpack(buffer.Bytes())
  113. buffer.Release()
  114. if err != nil {
  115. cancel(err)
  116. return err
  117. }
  118. timeout.Update()
  119. }
  120. metadataInQuery := metadata
  121. go func() error {
  122. response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
  123. if err != nil {
  124. cancel(err)
  125. return err
  126. }
  127. timeout.Update()
  128. responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
  129. if err != nil {
  130. cancel(err)
  131. return err
  132. }
  133. err = conn.WritePacket(responseBuffer, destination)
  134. if err != nil {
  135. cancel(err)
  136. }
  137. return err
  138. }()
  139. }
  140. })
  141. group.Cleanup(func() {
  142. conn.Close()
  143. })
  144. return group.Run(fastClose)
  145. }
  146. func newDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
  147. fastClose, cancel := common.ContextWithCancelCause(ctx)
  148. timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
  149. var group task.Group
  150. group.Append0(func(_ context.Context) error {
  151. for {
  152. var (
  153. message mDNS.Msg
  154. destination M.Socksaddr
  155. err error
  156. buffer *buf.Buffer
  157. )
  158. if len(cached) > 0 {
  159. packet := cached[0]
  160. cached = cached[1:]
  161. for _, counter := range readCounters {
  162. counter(int64(packet.Buffer.Len()))
  163. }
  164. err = message.Unpack(packet.Buffer.Bytes())
  165. packet.Buffer.Release()
  166. destination = packet.Destination
  167. N.PutPacketBuffer(packet)
  168. if err != nil {
  169. cancel(err)
  170. return err
  171. }
  172. } else {
  173. buffer, destination, err = readWaiter.WaitReadPacket()
  174. if err != nil {
  175. cancel(err)
  176. return err
  177. }
  178. for _, counter := range readCounters {
  179. counter(int64(buffer.Len()))
  180. }
  181. err = message.Unpack(buffer.Bytes())
  182. buffer.Release()
  183. if err != nil {
  184. cancel(err)
  185. return err
  186. }
  187. timeout.Update()
  188. }
  189. metadataInQuery := metadata
  190. go func() error {
  191. response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
  192. if err != nil {
  193. cancel(err)
  194. return err
  195. }
  196. timeout.Update()
  197. responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
  198. if err != nil {
  199. cancel(err)
  200. return err
  201. }
  202. err = conn.WritePacket(responseBuffer, destination)
  203. if err != nil {
  204. cancel(err)
  205. }
  206. return err
  207. }()
  208. }
  209. })
  210. group.Cleanup(func() {
  211. conn.Close()
  212. })
  213. return group.Run(fastClose)
  214. }