manager.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package endpoint
  2. import (
  3. "context"
  4. "os"
  5. "sync"
  6. "github.com/sagernet/sing-box/adapter"
  7. "github.com/sagernet/sing-box/common/taskmonitor"
  8. C "github.com/sagernet/sing-box/constant"
  9. "github.com/sagernet/sing-box/log"
  10. "github.com/sagernet/sing/common"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. )
  13. var _ adapter.EndpointManager = (*Manager)(nil)
  14. type Manager struct {
  15. logger log.ContextLogger
  16. registry adapter.EndpointRegistry
  17. access sync.Mutex
  18. started bool
  19. stage adapter.StartStage
  20. endpoints []adapter.Endpoint
  21. endpointByTag map[string]adapter.Endpoint
  22. }
  23. func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
  24. return &Manager{
  25. logger: logger,
  26. registry: registry,
  27. endpointByTag: make(map[string]adapter.Endpoint),
  28. }
  29. }
  30. func (m *Manager) Start(stage adapter.StartStage) error {
  31. m.access.Lock()
  32. defer m.access.Unlock()
  33. if m.started && m.stage >= stage {
  34. panic("already started")
  35. }
  36. m.started = true
  37. m.stage = stage
  38. if stage == adapter.StartStateStart {
  39. // started with outbound manager
  40. return nil
  41. }
  42. for _, endpoint := range m.endpoints {
  43. err := adapter.LegacyStart(endpoint, stage)
  44. if err != nil {
  45. return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
  46. }
  47. }
  48. return nil
  49. }
  50. func (m *Manager) Close() error {
  51. m.access.Lock()
  52. defer m.access.Unlock()
  53. if !m.started {
  54. return nil
  55. }
  56. m.started = false
  57. endpoints := m.endpoints
  58. m.endpoints = nil
  59. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  60. var err error
  61. for _, endpoint := range endpoints {
  62. monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
  63. err = E.Append(err, endpoint.Close(), func(err error) error {
  64. return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
  65. })
  66. monitor.Finish()
  67. }
  68. return nil
  69. }
  70. func (m *Manager) Endpoints() []adapter.Endpoint {
  71. m.access.Lock()
  72. defer m.access.Unlock()
  73. return m.endpoints
  74. }
  75. func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
  76. m.access.Lock()
  77. defer m.access.Unlock()
  78. endpoint, found := m.endpointByTag[tag]
  79. return endpoint, found
  80. }
  81. func (m *Manager) Remove(tag string) error {
  82. m.access.Lock()
  83. endpoint, found := m.endpointByTag[tag]
  84. if !found {
  85. m.access.Unlock()
  86. return os.ErrInvalid
  87. }
  88. delete(m.endpointByTag, tag)
  89. index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
  90. return it == endpoint
  91. })
  92. if index == -1 {
  93. panic("invalid endpoint index")
  94. }
  95. m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
  96. started := m.started
  97. m.access.Unlock()
  98. if started {
  99. return endpoint.Close()
  100. }
  101. return nil
  102. }
  103. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
  104. endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
  105. if err != nil {
  106. return err
  107. }
  108. m.access.Lock()
  109. defer m.access.Unlock()
  110. if m.started {
  111. for _, stage := range adapter.ListStartStages {
  112. err = adapter.LegacyStart(endpoint, stage)
  113. if err != nil {
  114. return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
  115. }
  116. }
  117. }
  118. if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
  119. if m.started {
  120. err = existsEndpoint.Close()
  121. if err != nil {
  122. return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
  123. }
  124. }
  125. existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
  126. return it == existsEndpoint
  127. })
  128. if existsIndex == -1 {
  129. panic("invalid endpoint index")
  130. }
  131. m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
  132. }
  133. m.endpoints = append(m.endpoints, endpoint)
  134. m.endpointByTag[tag] = endpoint
  135. return nil
  136. }