conn.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package trackerconn
  2. import (
  3. "net"
  4. "github.com/sagernet/sing/common/buf"
  5. "github.com/sagernet/sing/common/bufio"
  6. N "github.com/sagernet/sing/common/network"
  7. "go.uber.org/atomic"
  8. )
  9. func New(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *Conn {
  10. return &Conn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
  11. }
  12. func NewHook(conn net.Conn, readCounter func(n int64), writeCounter func(n int64)) *HookConn {
  13. return &HookConn{bufio.NewExtendedConn(conn), readCounter, writeCounter}
  14. }
  15. type Conn struct {
  16. N.ExtendedConn
  17. readCounter []*atomic.Int64
  18. writeCounter []*atomic.Int64
  19. }
  20. func (c *Conn) Read(p []byte) (n int, err error) {
  21. n, err = c.ExtendedConn.Read(p)
  22. for _, counter := range c.readCounter {
  23. counter.Add(int64(n))
  24. }
  25. return n, err
  26. }
  27. func (c *Conn) ReadBuffer(buffer *buf.Buffer) error {
  28. err := c.ExtendedConn.ReadBuffer(buffer)
  29. if err != nil {
  30. return err
  31. }
  32. for _, counter := range c.readCounter {
  33. counter.Add(int64(buffer.Len()))
  34. }
  35. return nil
  36. }
  37. func (c *Conn) Write(p []byte) (n int, err error) {
  38. n, err = c.ExtendedConn.Write(p)
  39. for _, counter := range c.writeCounter {
  40. counter.Add(int64(n))
  41. }
  42. return n, err
  43. }
  44. func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
  45. dataLen := int64(buffer.Len())
  46. err := c.ExtendedConn.WriteBuffer(buffer)
  47. if err != nil {
  48. return err
  49. }
  50. for _, counter := range c.writeCounter {
  51. counter.Add(dataLen)
  52. }
  53. return nil
  54. }
  55. func (c *Conn) Upstream() any {
  56. return c.ExtendedConn
  57. }
  58. type HookConn struct {
  59. N.ExtendedConn
  60. readCounter func(n int64)
  61. writeCounter func(n int64)
  62. }
  63. func (c *HookConn) Read(p []byte) (n int, err error) {
  64. n, err = c.ExtendedConn.Read(p)
  65. c.readCounter(int64(n))
  66. return n, err
  67. }
  68. func (c *HookConn) ReadBuffer(buffer *buf.Buffer) error {
  69. err := c.ExtendedConn.ReadBuffer(buffer)
  70. if err != nil {
  71. return err
  72. }
  73. c.readCounter(int64(buffer.Len()))
  74. return nil
  75. }
  76. func (c *HookConn) Write(p []byte) (n int, err error) {
  77. n, err = c.ExtendedConn.Write(p)
  78. c.writeCounter(int64(n))
  79. return n, err
  80. }
  81. func (c *HookConn) WriteBuffer(buffer *buf.Buffer) error {
  82. dataLen := int64(buffer.Len())
  83. err := c.ExtendedConn.WriteBuffer(buffer)
  84. if err != nil {
  85. return err
  86. }
  87. c.writeCounter(dataLen)
  88. return nil
  89. }
  90. func (c *HookConn) Upstream() any {
  91. return c.ExtendedConn
  92. }