packet_conn.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. package trackerconn
  2. import (
  3. "github.com/sagernet/sing/common/buf"
  4. M "github.com/sagernet/sing/common/metadata"
  5. N "github.com/sagernet/sing/common/network"
  6. "go.uber.org/atomic"
  7. )
  8. func NewPacket(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *PacketConn {
  9. return &PacketConn{conn, readCounter, writeCounter}
  10. }
  11. func NewHookPacket(conn N.PacketConn, readCounter func(n int64), writeCounter func(n int64)) *HookPacketConn {
  12. return &HookPacketConn{conn, readCounter, writeCounter}
  13. }
  14. type PacketConn struct {
  15. N.PacketConn
  16. readCounter []*atomic.Int64
  17. writeCounter []*atomic.Int64
  18. }
  19. func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
  20. destination, err = c.PacketConn.ReadPacket(buffer)
  21. if err == nil {
  22. for _, counter := range c.readCounter {
  23. counter.Add(int64(buffer.Len()))
  24. }
  25. }
  26. return
  27. }
  28. func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  29. dataLen := int64(buffer.Len())
  30. err := c.PacketConn.WritePacket(buffer, destination)
  31. if err != nil {
  32. return err
  33. }
  34. for _, counter := range c.writeCounter {
  35. counter.Add(dataLen)
  36. }
  37. return nil
  38. }
  39. func (c *PacketConn) Upstream() any {
  40. return c.PacketConn
  41. }
  42. type HookPacketConn struct {
  43. N.PacketConn
  44. readCounter func(n int64)
  45. writeCounter func(n int64)
  46. }
  47. func (c *HookPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
  48. destination, err = c.PacketConn.ReadPacket(buffer)
  49. if err == nil {
  50. c.readCounter(int64(buffer.Len()))
  51. }
  52. return
  53. }
  54. func (c *HookPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  55. dataLen := int64(buffer.Len())
  56. err := c.PacketConn.WritePacket(buffer, destination)
  57. if err != nil {
  58. return err
  59. }
  60. c.writeCounter(dataLen)
  61. return nil
  62. }
  63. func (c *HookPacketConn) Upstream() any {
  64. return c.PacketConn
  65. }