manager.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package trafficcontrol
  2. import (
  3. "io"
  4. "net"
  5. "sync"
  6. "sync/atomic"
  7. "github.com/sagernet/sing/common/buf"
  8. "github.com/sagernet/sing/common/bufio"
  9. M "github.com/sagernet/sing/common/metadata"
  10. N "github.com/sagernet/sing/common/network"
  11. )
// Manager aggregates per-user traffic counters for tracked connections.
// U is the key that identifies a user; every connection tracked under the
// same key shares a single *Traffic record.
//
// A Manager contains a sync.Mutex and must not be copied after first use.
// Create one with NewManager: the zero value has a nil users map, and the
// tracking methods write into that map.
type Manager[U comparable] struct {
	// access guards users; it is held while the map is looked up,
	// populated, or iterated.
	access sync.Mutex
	// users maps each user key to that user's shared counters.
	users map[U]*Traffic
}
  16. type Traffic struct {
  17. Upload uint64
  18. Download uint64
  19. }
  20. func NewManager[U comparable]() *Manager[U] {
  21. return &Manager[U]{
  22. users: make(map[U]*Traffic),
  23. }
  24. }
  25. func (m *Manager[U]) Reset() {
  26. m.users = make(map[U]*Traffic)
  27. }
  28. func (m *Manager[U]) TrackConnection(user U, conn net.Conn) net.Conn {
  29. m.access.Lock()
  30. defer m.access.Unlock()
  31. var traffic *Traffic
  32. if t, loaded := m.users[user]; loaded {
  33. traffic = t
  34. } else {
  35. traffic = new(Traffic)
  36. m.users[user] = traffic
  37. }
  38. return &TrackConn{conn, traffic}
  39. }
  40. func (m *Manager[U]) TrackPacketConnection(user U, conn N.PacketConn) N.PacketConn {
  41. m.access.Lock()
  42. defer m.access.Unlock()
  43. var traffic *Traffic
  44. if t, loaded := m.users[user]; loaded {
  45. traffic = t
  46. } else {
  47. traffic = new(Traffic)
  48. m.users[user] = traffic
  49. }
  50. return &TrackPacketConn{conn, traffic}
  51. }
  52. func (m *Manager[U]) ReadTraffics() map[U]Traffic {
  53. m.access.Lock()
  54. defer m.access.Unlock()
  55. trafficMap := make(map[U]Traffic)
  56. for user, traffic := range m.users {
  57. upload := atomic.SwapUint64(&traffic.Upload, 0)
  58. download := atomic.SwapUint64(&traffic.Download, 0)
  59. if upload == 0 && download == 0 {
  60. continue
  61. }
  62. trafficMap[user] = Traffic{
  63. Upload: upload,
  64. Download: download,
  65. }
  66. }
  67. return trafficMap
  68. }
// TrackConn wraps a net.Conn and adds the bytes moved through it to a
// Traffic record. Bytes read from the wrapped connection count as Upload,
// and bytes written to it count as Download.
//
// The embedded *Traffic may be shared with other connections of the same
// user, so its fields are only updated through sync/atomic.
// NOTE(review): field order matters — TrackConnection builds this struct
// with a positional literal.
type TrackConn struct {
	net.Conn
	*Traffic
}
  73. func (c *TrackConn) Read(p []byte) (n int, err error) {
  74. n, err = c.Conn.Read(p)
  75. if n > 0 {
  76. atomic.AddUint64(&c.Upload, uint64(n))
  77. }
  78. return
  79. }
  80. func (c *TrackConn) Write(p []byte) (n int, err error) {
  81. n, err = c.Conn.Write(p)
  82. if n > 0 {
  83. atomic.AddUint64(&c.Download, uint64(n))
  84. }
  85. return
  86. }
  87. func (c *TrackConn) WriteTo(w io.Writer) (n int64, err error) {
  88. n, err = bufio.Copy(w, c.Conn)
  89. if n > 0 {
  90. atomic.AddUint64(&c.Upload, uint64(n))
  91. }
  92. return
  93. }
  94. func (c *TrackConn) ReadFrom(r io.Reader) (n int64, err error) {
  95. n, err = bufio.Copy(c.Conn, r)
  96. if n > 0 {
  97. atomic.AddUint64(&c.Download, uint64(n))
  98. }
  99. return
  100. }
  101. func (c *TrackConn) Upstream() any {
  102. return c.Conn
  103. }
// TrackPacketConn wraps an N.PacketConn and adds packet payload sizes to a
// Traffic record. Successfully read packets count as Upload, and
// successfully written packets count as Download.
//
// The embedded *Traffic may be shared with other connections of the same
// user, so its fields are only updated through sync/atomic.
// NOTE(review): field order matters — TrackPacketConnection builds this
// struct with a positional literal.
type TrackPacketConn struct {
	N.PacketConn
	*Traffic
}
  108. func (c *TrackPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
  109. destination, err := c.PacketConn.ReadPacket(buffer)
  110. if err == nil {
  111. atomic.AddUint64(&c.Upload, uint64(buffer.Len()))
  112. }
  113. return destination, err
  114. }
  115. func (c *TrackPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
  116. n := buffer.Len()
  117. err := c.PacketConn.WritePacket(buffer, destination)
  118. if err == nil {
  119. atomic.AddUint64(&c.Download, uint64(n))
  120. }
  121. return err
  122. }
  123. func (c *TrackPacketConn) Upstream() any {
  124. return c.PacketConn
  125. }