dns.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package inbound
  import (
	"context"
	"encoding/binary"
	"io"
	"net"
	"strings"

	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	N "github.com/sagernet/sing/common/network"
	"golang.org/x/net/dns/dnsmessage"
)
  14. func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
  15. ctx = adapter.WithContext(ctx, &metadata)
  16. _buffer := buf.StackNewSize(1024)
  17. defer common.KeepAlive(_buffer)
  18. buffer := common.Dup(_buffer)
  19. defer buffer.Release()
  20. for {
  21. var queryLength uint16
  22. err := binary.Read(conn, binary.BigEndian, &queryLength)
  23. if err != nil {
  24. return err
  25. }
  26. if queryLength > 1024 {
  27. return io.ErrShortBuffer
  28. }
  29. buffer.FullReset()
  30. _, err = buffer.ReadFullFrom(conn, int(queryLength))
  31. if err != nil {
  32. return err
  33. }
  34. var message dnsmessage.Message
  35. err = message.Unpack(buffer.Bytes())
  36. if err != nil {
  37. return err
  38. }
  39. if len(message.Questions) > 0 {
  40. question := message.Questions[0]
  41. metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
  42. logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
  43. }
  44. go func() error {
  45. response, err := router.Exchange(ctx, &message)
  46. if err != nil {
  47. return err
  48. }
  49. _responseBuffer := buf.StackNewSize(1024)
  50. defer common.KeepAlive(_responseBuffer)
  51. responseBuffer := common.Dup(_responseBuffer)
  52. defer responseBuffer.Release()
  53. responseBuffer.Resize(2, 0)
  54. n, err := response.AppendPack(responseBuffer.Index(0))
  55. if err != nil {
  56. return err
  57. }
  58. responseBuffer.Truncate(len(n))
  59. binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
  60. _, err = conn.Write(responseBuffer.Bytes())
  61. return err
  62. }()
  63. }
  64. }
  65. func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.ContextLogger, conn N.PacketConn, metadata adapter.InboundContext) error {
  66. ctx = adapter.WithContext(ctx, &metadata)
  67. _buffer := buf.StackNewSize(1024)
  68. defer common.KeepAlive(_buffer)
  69. buffer := common.Dup(_buffer)
  70. defer buffer.Release()
  71. for {
  72. buffer.FullReset()
  73. destination, err := conn.ReadPacket(buffer)
  74. if err != nil {
  75. return err
  76. }
  77. var message dnsmessage.Message
  78. err = message.Unpack(buffer.Bytes())
  79. if err != nil {
  80. return err
  81. }
  82. if len(message.Questions) > 0 {
  83. question := message.Questions[0]
  84. metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
  85. logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
  86. }
  87. go func() error {
  88. response, err := router.Exchange(ctx, &message)
  89. if err != nil {
  90. return err
  91. }
  92. _responseBuffer := buf.StackNewSize(1024)
  93. defer common.KeepAlive(_responseBuffer)
  94. responseBuffer := common.Dup(_responseBuffer)
  95. defer responseBuffer.Release()
  96. n, err := response.AppendPack(responseBuffer.Index(0))
  97. if err != nil {
  98. return err
  99. }
  100. responseBuffer.Truncate(len(n))
  101. err = conn.WritePacket(responseBuffer, destination)
  102. return err
  103. }()
  104. }
  105. }
  106. func formatDNSQuestion(question dnsmessage.Question) string {
  107. return string(question.Name.Data[:question.Name.Length-1]) + " " + question.Type.String()[4:] + " " + question.Class.String()[5:]
  108. }