dns.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. package sniff
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "io"
  6. "os"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing/common"
  11. "github.com/sagernet/sing/common/buf"
  12. "github.com/sagernet/sing/common/task"
  13. "golang.org/x/net/dns/dnsmessage"
  14. )
  15. func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
  16. var length uint16
  17. err := binary.Read(reader, binary.BigEndian, &length)
  18. if err != nil {
  19. return nil, err
  20. }
  21. if length > 512 {
  22. return nil, os.ErrInvalid
  23. }
  24. _buffer := buf.StackNewSize(int(length))
  25. defer common.KeepAlive(_buffer)
  26. buffer := common.Dup(_buffer)
  27. defer buffer.Release()
  28. readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
  29. err = task.Run(readCtx, func() error {
  30. return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
  31. })
  32. cancel()
  33. if err != nil {
  34. return nil, err
  35. }
  36. return DomainNameQuery(readCtx, buffer.Bytes())
  37. }
  38. func DomainNameQuery(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
  39. var parser dnsmessage.Parser
  40. _, err := parser.Start(packet)
  41. if err != nil {
  42. return nil, err
  43. }
  44. question, err := parser.Question()
  45. if err != nil {
  46. return nil, os.ErrInvalid
  47. }
  48. domain := question.Name.String()
  49. if question.Class == dnsmessage.ClassINET && (question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA) && IsDomainName(domain) {
  50. return &adapter.InboundContext{Protocol: C.ProtocolDNS, Domain: domain}, nil
  51. }
  52. return nil, os.ErrInvalid
  53. }