sniff.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. package sniff
  2. import (
  3. "bytes"
  4. "context"
  5. "io"
  6. "net"
  7. "os"
  8. "time"
  9. "github.com/sagernet/sing-box/adapter"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing/common/buf"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. )
  14. type (
  15. StreamSniffer = func(ctx context.Context, reader io.Reader) (*adapter.InboundContext, error)
  16. PacketSniffer = func(ctx context.Context, packet []byte) (*adapter.InboundContext, error)
  17. )
  18. func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
  19. if timeout == 0 {
  20. timeout = C.ReadPayloadTimeout
  21. }
  22. err := conn.SetReadDeadline(time.Now().Add(timeout))
  23. if err != nil {
  24. return nil, err
  25. }
  26. _, err = buffer.ReadOnceFrom(conn)
  27. err = E.Errors(err, conn.SetReadDeadline(time.Time{}))
  28. if err != nil {
  29. return nil, err
  30. }
  31. var metadata *adapter.InboundContext
  32. for _, sniffer := range sniffers {
  33. metadata, err = sniffer(ctx, bytes.NewReader(buffer.Bytes()))
  34. if err != nil {
  35. continue
  36. }
  37. return metadata, nil
  38. }
  39. return nil, os.ErrInvalid
  40. }
  41. func PeekPacket(ctx context.Context, packet []byte, sniffers ...PacketSniffer) (*adapter.InboundContext, error) {
  42. for _, sniffer := range sniffers {
  43. sniffMetadata, err := sniffer(ctx, packet)
  44. if err != nil {
  45. continue
  46. }
  47. return sniffMetadata, nil
  48. }
  49. return nil, os.ErrInvalid
  50. }