dns.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. package sniff
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "os"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing/common"
  11. "github.com/sagernet/sing/common/buf"
  12. M "github.com/sagernet/sing/common/metadata"
  13. "github.com/sagernet/sing/common/task"
  14. mDNS "github.com/miekg/dns"
  15. )
  16. func StreamDomainNameQuery(readCtx context.Context, metadata *adapter.InboundContext, reader io.Reader) error {
  17. var length uint16
  18. err := binary.Read(reader, binary.BigEndian, &length)
  19. if err != nil {
  20. return os.ErrInvalid
  21. }
  22. if length == 0 {
  23. return os.ErrInvalid
  24. }
  25. buffer := buf.NewSize(int(length))
  26. defer buffer.Release()
  27. readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
  28. var readTask task.Group
  29. readTask.Append0(func(ctx context.Context) error {
  30. return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
  31. })
  32. err = readTask.Run(readCtx)
  33. cancel()
  34. if err != nil {
  35. return err
  36. }
  37. return DomainNameQuery(readCtx, metadata, buffer.Bytes())
  38. }
  39. func DomainNameQuery(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
  40. var msg mDNS.Msg
  41. err := msg.Unpack(packet)
  42. if err != nil {
  43. return err
  44. }
  45. if len(msg.Question) == 0 || msg.Question[0].Qclass != mDNS.ClassINET || !M.IsDomainName(msg.Question[0].Name) {
  46. return os.ErrInvalid
  47. }
  48. metadata.Protocol = C.ProtocolDNS
  49. return nil
  50. }