1
0

device_nat.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package wireguard
  2. import (
  3. "context"
  4. "sync/atomic"
  5. "time"
  6. "github.com/sagernet/sing-box/adapter"
  7. "github.com/sagernet/sing-box/log"
  8. "github.com/sagernet/sing-tun"
  9. "github.com/sagernet/sing-tun/ping"
  10. "github.com/sagernet/sing/common/buf"
  11. "github.com/sagernet/sing/common/logger"
  12. )
  13. var _ Device = (*natDeviceWrapper)(nil)
  14. type natDeviceWrapper struct {
  15. Device
  16. ctx context.Context
  17. logger logger.ContextLogger
  18. packetOutbound chan *buf.Buffer
  19. rewriter *ping.Rewriter
  20. buffer [][]byte
  21. }
  22. func NewNATDevice(ctx context.Context, logger logger.ContextLogger, upstream Device) NatDevice {
  23. wrapper := &natDeviceWrapper{
  24. Device: upstream,
  25. ctx: ctx,
  26. logger: logger,
  27. packetOutbound: make(chan *buf.Buffer, 256),
  28. rewriter: ping.NewRewriter(ctx, logger, upstream.Inet4Address(), upstream.Inet6Address()),
  29. }
  30. return wrapper
  31. }
  32. func (d *natDeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
  33. select {
  34. case packet := <-d.packetOutbound:
  35. defer packet.Release()
  36. sizes[0] = copy(bufs[0][offset:], packet.Bytes())
  37. return 1, nil
  38. default:
  39. }
  40. return d.Device.Read(bufs, sizes, offset)
  41. }
  42. func (d *natDeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
  43. for _, buffer := range bufs {
  44. handled, err := d.rewriter.WriteBack(buffer[offset:])
  45. if handled {
  46. if err != nil {
  47. return 0, err
  48. }
  49. } else {
  50. d.buffer = append(d.buffer, buffer)
  51. }
  52. }
  53. if len(d.buffer) > 0 {
  54. _, err := d.Device.Write(d.buffer, offset)
  55. if err != nil {
  56. return 0, err
  57. }
  58. d.buffer = d.buffer[:0]
  59. }
  60. return 0, nil
  61. }
  62. func (d *natDeviceWrapper) CreateDestination(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
  63. ctx := log.ContextWithNewID(d.ctx)
  64. session := tun.DirectRouteSession{
  65. Source: metadata.Source.Addr,
  66. Destination: metadata.Destination.Addr,
  67. }
  68. d.rewriter.CreateSession(session, routeContext)
  69. d.logger.InfoContext(ctx, "linked ", metadata.Network, " connection from ", metadata.Source.AddrString(), " to ", metadata.Destination.AddrString())
  70. return &natDestination{device: d, session: session}, nil
  71. }
  72. var _ tun.DirectRouteDestination = (*natDestination)(nil)
  73. type natDestination struct {
  74. device *natDeviceWrapper
  75. session tun.DirectRouteSession
  76. closed atomic.Bool
  77. }
  78. func (d *natDestination) WritePacket(buffer *buf.Buffer) error {
  79. d.device.rewriter.RewritePacket(buffer.Bytes())
  80. d.device.packetOutbound <- buffer
  81. return nil
  82. }
  83. func (d *natDestination) Close() error {
  84. d.closed.Store(true)
  85. d.device.rewriter.DeleteSession(d.session)
  86. return nil
  87. }
  88. func (d *natDestination) IsClosed() bool {
  89. return d.closed.Load()
  90. }