dns.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. package sniff
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "os"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing/common"
  11. "github.com/sagernet/sing/common/buf"
  12. M "github.com/sagernet/sing/common/metadata"
  13. "github.com/sagernet/sing/common/task"
  14. mDNS "github.com/miekg/dns"
  15. )
  16. func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
  17. var length uint16
  18. err := binary.Read(reader, binary.BigEndian, &length)
  19. if err != nil {
  20. return nil, err
  21. }
  22. if length == 0 {
  23. return nil, os.ErrInvalid
  24. }
  25. _buffer := buf.StackNewSize(int(length))
  26. defer common.KeepAlive(_buffer)
  27. buffer := common.Dup(_buffer)
  28. defer buffer.Release()
  29. readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
  30. var readTask task.Group
  31. readTask.Append0(func(ctx context.Context) error {
  32. return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
  33. })
  34. err = readTask.Run(readCtx)
  35. cancel()
  36. if err != nil {
  37. return nil, err
  38. }
  39. return DomainNameQuery(readCtx, buffer.Bytes())
  40. }
  41. func DomainNameQuery(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
  42. var msg mDNS.Msg
  43. err := msg.Unpack(packet)
  44. if err != nil {
  45. return nil, err
  46. }
  47. if len(msg.Question) == 0 || msg.Question[0].Qclass != mDNS.ClassINET || !M.IsDomainName(msg.Question[0].Name) {
  48. return nil, os.ErrInvalid
  49. }
  50. return &adapter.InboundContext{Protocol: C.ProtocolDNS}, nil
  51. }