manager.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. package outbound
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. )
  16. var _ adapter.OutboundManager = (*Manager)(nil)
  17. type Manager struct {
  18. logger log.ContextLogger
  19. registry adapter.OutboundRegistry
  20. defaultTag string
  21. access sync.Mutex
  22. started bool
  23. stage adapter.StartStage
  24. outbounds []adapter.Outbound
  25. outboundByTag map[string]adapter.Outbound
  26. dependByTag map[string][]string
  27. defaultOutbound adapter.Outbound
  28. defaultOutboundFallback adapter.Outbound
  29. }
  30. func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager {
  31. return &Manager{
  32. logger: logger,
  33. registry: registry,
  34. defaultTag: defaultTag,
  35. outboundByTag: make(map[string]adapter.Outbound),
  36. dependByTag: make(map[string][]string),
  37. }
  38. }
  39. func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) {
  40. m.defaultOutboundFallback = defaultOutboundFallback
  41. }
  42. func (m *Manager) Start(stage adapter.StartStage) error {
  43. m.access.Lock()
  44. if m.started && m.stage >= stage {
  45. panic("already started")
  46. }
  47. m.started = true
  48. m.stage = stage
  49. outbounds := m.outbounds
  50. m.access.Unlock()
  51. if stage == adapter.StartStateStart {
  52. return m.startOutbounds(outbounds)
  53. } else {
  54. for _, outbound := range outbounds {
  55. err := adapter.LegacyStart(outbound, stage)
  56. if err != nil {
  57. return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  58. }
  59. }
  60. }
  61. return nil
  62. }
  63. func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
  64. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  65. started := make(map[string]bool)
  66. for {
  67. canContinue := false
  68. startOne:
  69. for _, outboundToStart := range outbounds {
  70. outboundTag := outboundToStart.Tag()
  71. if started[outboundTag] {
  72. continue
  73. }
  74. dependencies := outboundToStart.Dependencies()
  75. for _, dependency := range dependencies {
  76. if !started[dependency] {
  77. continue startOne
  78. }
  79. }
  80. started[outboundTag] = true
  81. canContinue = true
  82. if starter, isStarter := outboundToStart.(interface {
  83. Start() error
  84. }); isStarter {
  85. monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  86. err := starter.Start()
  87. monitor.Finish()
  88. if err != nil {
  89. return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  90. }
  91. }
  92. }
  93. if len(started) == len(outbounds) {
  94. break
  95. }
  96. if canContinue {
  97. continue
  98. }
  99. currentOutbound := common.Find(outbounds, func(it adapter.Outbound) bool {
  100. return !started[it.Tag()]
  101. })
  102. var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
  103. lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
  104. problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  105. return !started[it]
  106. })
  107. if common.Contains(oTree, problemOutboundTag) {
  108. return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
  109. }
  110. m.access.Lock()
  111. problemOutbound := m.outboundByTag[problemOutboundTag]
  112. m.access.Unlock()
  113. if problemOutbound == nil {
  114. return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
  115. }
  116. return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
  117. }
  118. return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
  119. }
  120. return nil
  121. }
  122. func (m *Manager) Close() error {
  123. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  124. m.access.Lock()
  125. if !m.started {
  126. m.access.Unlock()
  127. return nil
  128. }
  129. m.started = false
  130. outbounds := m.outbounds
  131. m.outbounds = nil
  132. m.access.Unlock()
  133. var err error
  134. for _, outbound := range outbounds {
  135. if closer, isCloser := outbound.(io.Closer); isCloser {
  136. monitor.Start("close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  137. err = E.Append(err, closer.Close(), func(err error) error {
  138. return E.Cause(err, "close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  139. })
  140. monitor.Finish()
  141. }
  142. }
  143. return nil
  144. }
  145. func (m *Manager) Outbounds() []adapter.Outbound {
  146. m.access.Lock()
  147. defer m.access.Unlock()
  148. return m.outbounds
  149. }
  150. func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
  151. m.access.Lock()
  152. defer m.access.Unlock()
  153. outbound, found := m.outboundByTag[tag]
  154. return outbound, found
  155. }
  156. func (m *Manager) Default() adapter.Outbound {
  157. m.access.Lock()
  158. defer m.access.Unlock()
  159. if m.defaultOutbound != nil {
  160. return m.defaultOutbound
  161. } else {
  162. return m.defaultOutboundFallback
  163. }
  164. }
  165. func (m *Manager) Remove(tag string) error {
  166. m.access.Lock()
  167. outbound, found := m.outboundByTag[tag]
  168. if !found {
  169. m.access.Unlock()
  170. return os.ErrInvalid
  171. }
  172. delete(m.outboundByTag, tag)
  173. index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  174. return it == outbound
  175. })
  176. if index == -1 {
  177. panic("invalid inbound index")
  178. }
  179. m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
  180. started := m.started
  181. if m.defaultOutbound == outbound {
  182. if len(m.outbounds) > 0 {
  183. m.defaultOutbound = m.outbounds[0]
  184. m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
  185. } else {
  186. m.defaultOutbound = nil
  187. }
  188. }
  189. dependBy := m.dependByTag[tag]
  190. if len(dependBy) > 0 {
  191. return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  192. }
  193. dependencies := outbound.Dependencies()
  194. for _, dependency := range dependencies {
  195. if len(m.dependByTag[dependency]) == 1 {
  196. delete(m.dependByTag, dependency)
  197. } else {
  198. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  199. return it != tag
  200. })
  201. }
  202. }
  203. m.access.Unlock()
  204. if started {
  205. return common.Close(outbound)
  206. }
  207. return nil
  208. }
  209. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
  210. if tag == "" {
  211. return os.ErrInvalid
  212. }
  213. outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
  214. if err != nil {
  215. return err
  216. }
  217. m.access.Lock()
  218. defer m.access.Unlock()
  219. if m.started {
  220. for _, stage := range adapter.ListStartStages {
  221. err = adapter.LegacyStart(outbound, stage)
  222. if err != nil {
  223. return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  224. }
  225. }
  226. }
  227. if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
  228. if m.started {
  229. err = common.Close(existsOutbound)
  230. if err != nil {
  231. return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
  232. }
  233. }
  234. existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  235. return it == existsOutbound
  236. })
  237. if existsIndex == -1 {
  238. panic("invalid inbound index")
  239. }
  240. m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
  241. }
  242. m.outbounds = append(m.outbounds, outbound)
  243. m.outboundByTag[tag] = outbound
  244. dependencies := outbound.Dependencies()
  245. for _, dependency := range dependencies {
  246. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  247. }
  248. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
  249. m.defaultOutbound = outbound
  250. if m.started {
  251. m.logger.Info("updated default outbound to ", outbound.Tag())
  252. }
  253. }
  254. return nil
  255. }