dns.go 6.4 KB


  1. package outbound
import (
	"context"
	"encoding/binary"
	"errors"
	"net"
	"os"

	"github.com/sagernet/sing-box/adapter"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-dns"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	"github.com/sagernet/sing/common/canceler"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/task"

	mDNS "github.com/miekg/dns"
)
  19. var _ adapter.Outbound = (*DNS)(nil)
  20. type DNS struct {
  21. myOutboundAdapter
  22. }
  23. func NewDNS(router adapter.Router, tag string) *DNS {
  24. return &DNS{
  25. myOutboundAdapter{
  26. protocol: C.TypeDNS,
  27. network: []string{N.NetworkTCP, N.NetworkUDP},
  28. router: router,
  29. tag: tag,
  30. },
  31. }
  32. }
  33. func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  34. return nil, os.ErrInvalid
  35. }
  36. func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  37. return nil, os.ErrInvalid
  38. }
  39. // Deprecated
  40. func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
  41. metadata.Destination = M.Socksaddr{}
  42. defer conn.Close()
  43. for {
  44. err := d.handleConnection(ctx, conn, metadata)
  45. if err != nil {
  46. return err
  47. }
  48. }
  49. }
  50. func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
  51. var queryLength uint16
  52. err := binary.Read(conn, binary.BigEndian, &queryLength)
  53. if err != nil {
  54. return err
  55. }
  56. if queryLength == 0 {
  57. return dns.RCodeFormatError
  58. }
  59. buffer := buf.NewSize(int(queryLength))
  60. defer buffer.Release()
  61. _, err = buffer.ReadFullFrom(conn, int(queryLength))
  62. if err != nil {
  63. return err
  64. }
  65. var message mDNS.Msg
  66. err = message.Unpack(buffer.Bytes())
  67. if err != nil {
  68. return err
  69. }
  70. metadataInQuery := metadata
  71. go func() error {
  72. response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
  73. if err != nil {
  74. return err
  75. }
  76. responseBuffer := buf.NewPacket()
  77. defer responseBuffer.Release()
  78. responseBuffer.Resize(2, 0)
  79. n, err := response.PackBuffer(responseBuffer.FreeBytes())
  80. if err != nil {
  81. return err
  82. }
  83. responseBuffer.Truncate(len(n))
  84. binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
  85. _, err = conn.Write(responseBuffer.Bytes())
  86. return err
  87. }()
  88. return nil
  89. }
  90. // Deprecated
  91. func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
  92. metadata.Destination = M.Socksaddr{}
  93. var reader N.PacketReader = conn
  94. var counters []N.CountFunc
  95. var cachedPackets []*N.PacketBuffer
  96. for {
  97. reader, counters = N.UnwrapCountPacketReader(reader, counters)
  98. if cachedReader, isCached := reader.(N.CachedPacketReader); isCached {
  99. packet := cachedReader.ReadCachedPacket()
  100. if packet != nil {
  101. cachedPackets = append(cachedPackets, packet)
  102. continue
  103. }
  104. }
  105. if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
  106. readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
  107. return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
  108. }
  109. break
  110. }
  111. fastClose, cancel := common.ContextWithCancelCause(ctx)
  112. timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
  113. var group task.Group
  114. group.Append0(func(_ context.Context) error {
  115. for {
  116. var message mDNS.Msg
  117. var destination M.Socksaddr
  118. var err error
  119. if len(cachedPackets) > 0 {
  120. packet := cachedPackets[0]
  121. cachedPackets = cachedPackets[1:]
  122. for _, counter := range counters {
  123. counter(int64(packet.Buffer.Len()))
  124. }
  125. err = message.Unpack(packet.Buffer.Bytes())
  126. packet.Buffer.Release()
  127. if err != nil {
  128. cancel(err)
  129. return err
  130. }
  131. destination = packet.Destination
  132. } else {
  133. buffer := buf.NewPacket()
  134. destination, err = conn.ReadPacket(buffer)
  135. if err != nil {
  136. buffer.Release()
  137. cancel(err)
  138. return err
  139. }
  140. for _, counter := range counters {
  141. counter(int64(buffer.Len()))
  142. }
  143. err = message.Unpack(buffer.Bytes())
  144. buffer.Release()
  145. if err != nil {
  146. cancel(err)
  147. return err
  148. }
  149. timeout.Update()
  150. }
  151. metadataInQuery := metadata
  152. go func() error {
  153. response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
  154. if err != nil {
  155. cancel(err)
  156. return err
  157. }
  158. timeout.Update()
  159. responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
  160. if err != nil {
  161. cancel(err)
  162. return err
  163. }
  164. err = conn.WritePacket(responseBuffer, destination)
  165. if err != nil {
  166. cancel(err)
  167. }
  168. return err
  169. }()
  170. }
  171. })
  172. group.Cleanup(func() {
  173. conn.Close()
  174. })
  175. return group.Run(fastClose)
  176. }
  177. func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
  178. fastClose, cancel := common.ContextWithCancelCause(ctx)
  179. timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
  180. var group task.Group
  181. group.Append0(func(_ context.Context) error {
  182. for {
  183. var (
  184. message mDNS.Msg
  185. destination M.Socksaddr
  186. err error
  187. buffer *buf.Buffer
  188. )
  189. if len(cached) > 0 {
  190. packet := cached[0]
  191. cached = cached[1:]
  192. for _, counter := range readCounters {
  193. counter(int64(packet.Buffer.Len()))
  194. }
  195. err = message.Unpack(packet.Buffer.Bytes())
  196. packet.Buffer.Release()
  197. if err != nil {
  198. cancel(err)
  199. return err
  200. }
  201. destination = packet.Destination
  202. } else {
  203. buffer, destination, err = readWaiter.WaitReadPacket()
  204. if err != nil {
  205. cancel(err)
  206. return err
  207. }
  208. for _, counter := range readCounters {
  209. counter(int64(buffer.Len()))
  210. }
  211. err = message.Unpack(buffer.Bytes())
  212. buffer.Release()
  213. if err != nil {
  214. cancel(err)
  215. return err
  216. }
  217. timeout.Update()
  218. }
  219. metadataInQuery := metadata
  220. go func() error {
  221. response, err := d.router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
  222. if err != nil {
  223. cancel(err)
  224. return err
  225. }
  226. timeout.Update()
  227. responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
  228. if err != nil {
  229. cancel(err)
  230. return err
  231. }
  232. err = conn.WritePacket(responseBuffer, destination)
  233. if err != nil {
  234. cancel(err)
  235. }
  236. return err
  237. }()
  238. }
  239. })
  240. group.Cleanup(func() {
  241. conn.Close()
  242. })
  243. return group.Run(fastClose)
  244. }