sniff.go 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. package sniff
  2. import (
  3. "bytes"
  4. "context"
  5. "io"
  6. "net"
  7. "time"
  8. "github.com/sagernet/sing-box/adapter"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing/common/buf"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. )
  13. type (
  14. StreamSniffer = func(ctx context.Context, reader io.Reader) (*adapter.InboundContext, error)
  15. PacketSniffer = func(ctx context.Context, packet []byte) (*adapter.InboundContext, error)
  16. )
  17. func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) (*adapter.InboundContext, error) {
  18. if timeout == 0 {
  19. timeout = C.ReadPayloadTimeout
  20. }
  21. err := conn.SetReadDeadline(time.Now().Add(timeout))
  22. if err != nil {
  23. return nil, err
  24. }
  25. _, err = buffer.ReadOnceFrom(conn)
  26. err = E.Errors(err, conn.SetReadDeadline(time.Time{}))
  27. if err != nil {
  28. return nil, err
  29. }
  30. var metadata *adapter.InboundContext
  31. var errors []error
  32. for _, sniffer := range sniffers {
  33. metadata, err = sniffer(ctx, bytes.NewReader(buffer.Bytes()))
  34. if metadata != nil {
  35. return metadata, nil
  36. }
  37. errors = append(errors, err)
  38. }
  39. return nil, E.Errors(errors...)
  40. }
  41. func PeekPacket(ctx context.Context, packet []byte, sniffers ...PacketSniffer) (*adapter.InboundContext, error) {
  42. var errors []error
  43. for _, sniffer := range sniffers {
  44. metadata, err := sniffer(ctx, packet)
  45. if metadata != nil {
  46. return metadata, nil
  47. }
  48. errors = append(errors, err)
  49. }
  50. return nil, E.Errors(errors...)
  51. }