manager.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package endpoint
  2. import (
  3. "context"
  4. "os"
  5. "sync"
  6. "time"
  7. "github.com/sagernet/sing-box/adapter"
  8. "github.com/sagernet/sing-box/common/taskmonitor"
  9. C "github.com/sagernet/sing-box/constant"
  10. "github.com/sagernet/sing-box/log"
  11. "github.com/sagernet/sing/common"
  12. E "github.com/sagernet/sing/common/exceptions"
  13. F "github.com/sagernet/sing/common/format"
  14. )
  15. var _ adapter.EndpointManager = (*Manager)(nil)
  16. type Manager struct {
  17. logger log.ContextLogger
  18. registry adapter.EndpointRegistry
  19. access sync.Mutex
  20. started bool
  21. stage adapter.StartStage
  22. endpoints []adapter.Endpoint
  23. endpointByTag map[string]adapter.Endpoint
  24. }
  25. func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
  26. return &Manager{
  27. logger: logger,
  28. registry: registry,
  29. endpointByTag: make(map[string]adapter.Endpoint),
  30. }
  31. }
  32. func (m *Manager) Start(stage adapter.StartStage) error {
  33. m.access.Lock()
  34. defer m.access.Unlock()
  35. if m.started && m.stage >= stage {
  36. panic("already started")
  37. }
  38. m.started = true
  39. m.stage = stage
  40. if stage == adapter.StartStateStart {
  41. // started with outbound manager
  42. return nil
  43. }
  44. for _, endpoint := range m.endpoints {
  45. name := "endpoint/" + endpoint.Type() + "[" + endpoint.Tag() + "]"
  46. m.logger.Trace(stage, " ", name)
  47. startTime := time.Now()
  48. err := adapter.LegacyStart(endpoint, stage)
  49. if err != nil {
  50. return E.Cause(err, stage, " ", name)
  51. }
  52. m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  53. }
  54. return nil
  55. }
  56. func (m *Manager) Close() error {
  57. m.access.Lock()
  58. defer m.access.Unlock()
  59. if !m.started {
  60. return nil
  61. }
  62. m.started = false
  63. endpoints := m.endpoints
  64. m.endpoints = nil
  65. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  66. var err error
  67. for _, endpoint := range endpoints {
  68. name := "endpoint/" + endpoint.Type() + "[" + endpoint.Tag() + "]"
  69. m.logger.Trace("close ", name)
  70. startTime := time.Now()
  71. monitor.Start("close ", name)
  72. err = E.Append(err, endpoint.Close(), func(err error) error {
  73. return E.Cause(err, "close ", name)
  74. })
  75. monitor.Finish()
  76. m.logger.Trace("close ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  77. }
  78. return nil
  79. }
  80. func (m *Manager) Endpoints() []adapter.Endpoint {
  81. m.access.Lock()
  82. defer m.access.Unlock()
  83. return m.endpoints
  84. }
  85. func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
  86. m.access.Lock()
  87. defer m.access.Unlock()
  88. endpoint, found := m.endpointByTag[tag]
  89. return endpoint, found
  90. }
  91. func (m *Manager) Remove(tag string) error {
  92. m.access.Lock()
  93. endpoint, found := m.endpointByTag[tag]
  94. if !found {
  95. m.access.Unlock()
  96. return os.ErrInvalid
  97. }
  98. delete(m.endpointByTag, tag)
  99. index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
  100. return it == endpoint
  101. })
  102. if index == -1 {
  103. panic("invalid endpoint index")
  104. }
  105. m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
  106. started := m.started
  107. m.access.Unlock()
  108. if started {
  109. return endpoint.Close()
  110. }
  111. return nil
  112. }
  113. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
  114. endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
  115. if err != nil {
  116. return err
  117. }
  118. m.access.Lock()
  119. defer m.access.Unlock()
  120. if m.started {
  121. name := "endpoint/" + endpoint.Type() + "[" + endpoint.Tag() + "]"
  122. for _, stage := range adapter.ListStartStages {
  123. m.logger.Trace(stage, " ", name)
  124. startTime := time.Now()
  125. err = adapter.LegacyStart(endpoint, stage)
  126. if err != nil {
  127. return E.Cause(err, stage, " ", name)
  128. }
  129. m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  130. }
  131. }
  132. if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
  133. if m.started {
  134. err = existsEndpoint.Close()
  135. if err != nil {
  136. return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
  137. }
  138. }
  139. existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
  140. return it == existsEndpoint
  141. })
  142. if existsIndex == -1 {
  143. panic("invalid endpoint index")
  144. }
  145. m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
  146. }
  147. m.endpoints = append(m.endpoints, endpoint)
  148. m.endpointByTag[tag] = endpoint
  149. return nil
  150. }