dns.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. package outbound
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "net"
  7. "os"
  8. "sync"
  9. "time"
  10. "github.com/sagernet/sing-box/adapter"
  11. C "github.com/sagernet/sing-box/constant"
  12. "github.com/sagernet/sing-box/log"
  13. "github.com/sagernet/sing/common"
  14. "github.com/sagernet/sing/common/buf"
  15. M "github.com/sagernet/sing/common/metadata"
  16. N "github.com/sagernet/sing/common/network"
  17. "github.com/sagernet/sing/common/task"
  18. "golang.org/x/net/dns/dnsmessage"
  19. )
  20. var _ adapter.Outbound = (*DNS)(nil)
  21. type DNS struct {
  22. myOutboundAdapter
  23. }
  24. func NewDNS(router adapter.Router, logger log.ContextLogger, tag string) *DNS {
  25. return &DNS{
  26. myOutboundAdapter{
  27. protocol: C.TypeDNS,
  28. network: []string{C.NetworkTCP, C.NetworkUDP},
  29. router: router,
  30. logger: logger,
  31. tag: tag,
  32. },
  33. }
  34. }
  35. func (d *DNS) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  36. return nil, os.ErrInvalid
  37. }
  38. func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  39. return nil, os.ErrInvalid
  40. }
  41. func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
  42. defer conn.Close()
  43. ctx = adapter.WithContext(ctx, &metadata)
  44. _buffer := buf.StackNewSize(1024)
  45. defer common.KeepAlive(_buffer)
  46. buffer := common.Dup(_buffer)
  47. defer buffer.Release()
  48. for {
  49. var queryLength uint16
  50. err := binary.Read(conn, binary.BigEndian, &queryLength)
  51. if err != nil {
  52. return err
  53. }
  54. if queryLength > 1024 {
  55. return io.ErrShortBuffer
  56. }
  57. buffer.FullReset()
  58. _, err = buffer.ReadFullFrom(conn, int(queryLength))
  59. if err != nil {
  60. return err
  61. }
  62. var message dnsmessage.Message
  63. err = message.Unpack(buffer.Bytes())
  64. if err != nil {
  65. return err
  66. }
  67. if len(message.Questions) > 0 {
  68. question := message.Questions[0]
  69. metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
  70. d.logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
  71. }
  72. go func() error {
  73. response, err := d.router.Exchange(ctx, &message)
  74. if err != nil {
  75. return err
  76. }
  77. _responseBuffer := buf.StackNewSize(1024)
  78. defer common.KeepAlive(_responseBuffer)
  79. responseBuffer := common.Dup(_responseBuffer)
  80. defer responseBuffer.Release()
  81. responseBuffer.Resize(2, 0)
  82. n, err := response.AppendPack(responseBuffer.Index(0))
  83. if err != nil {
  84. return err
  85. }
  86. responseBuffer.Truncate(len(n))
  87. binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
  88. _, err = conn.Write(responseBuffer.Bytes())
  89. return err
  90. }()
  91. }
  92. }
  93. func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
  94. defer conn.Close()
  95. ctx = adapter.WithContext(ctx, &metadata)
  96. _buffer := buf.StackNewSize(1024)
  97. defer common.KeepAlive(_buffer)
  98. buffer := common.Dup(_buffer)
  99. defer buffer.Release()
  100. var wg sync.WaitGroup
  101. fastClose, cancel := context.WithCancel(ctx)
  102. err := task.Run(fastClose, func() error {
  103. var count int
  104. for {
  105. buffer.FullReset()
  106. destination, err := conn.ReadPacket(buffer)
  107. if err != nil {
  108. return err
  109. }
  110. var message dnsmessage.Message
  111. err = message.Unpack(buffer.Bytes())
  112. if err != nil {
  113. return err
  114. }
  115. if len(message.Questions) > 0 {
  116. question := message.Questions[0]
  117. metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
  118. d.logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
  119. }
  120. wg.Add(1)
  121. go func() error {
  122. defer wg.Done()
  123. response, err := d.router.Exchange(ctx, &message)
  124. if err != nil {
  125. return err
  126. }
  127. _responseBuffer := buf.StackNewSize(1024)
  128. defer common.KeepAlive(_responseBuffer)
  129. responseBuffer := common.Dup(_responseBuffer)
  130. defer responseBuffer.Release()
  131. n, err := response.AppendPack(responseBuffer.Index(0))
  132. if err != nil {
  133. return err
  134. }
  135. responseBuffer.Truncate(len(n))
  136. err = conn.WritePacket(responseBuffer, destination)
  137. return err
  138. }()
  139. count++
  140. if count == 2 {
  141. break
  142. }
  143. }
  144. cancel()
  145. return nil
  146. }, func() error {
  147. timer := time.NewTimer(5 * time.Second)
  148. select {
  149. case <-timer.C:
  150. cancel()
  151. case <-fastClose.Done():
  152. }
  153. return nil
  154. })
  155. wg.Wait()
  156. return err
  157. }
  158. func formatDNSQuestion(question dnsmessage.Question) string {
  159. return string(question.Name.Data[:question.Name.Length-1]) + " " + question.Type.String()[4:] + " " + question.Class.String()[5:]
  160. }