manager.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package outbound
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. )
  16. var _ adapter.OutboundManager = (*Manager)(nil)
  17. type Manager struct {
  18. logger log.ContextLogger
  19. registry adapter.OutboundRegistry
  20. endpoint adapter.EndpointManager
  21. defaultTag string
  22. access sync.RWMutex
  23. started bool
  24. stage adapter.StartStage
  25. outbounds []adapter.Outbound
  26. outboundByTag map[string]adapter.Outbound
  27. dependByTag map[string][]string
  28. defaultOutbound adapter.Outbound
  29. defaultOutboundFallback func() (adapter.Outbound, error)
  30. }
  31. func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
  32. return &Manager{
  33. logger: logger,
  34. registry: registry,
  35. endpoint: endpoint,
  36. defaultTag: defaultTag,
  37. outboundByTag: make(map[string]adapter.Outbound),
  38. dependByTag: make(map[string][]string),
  39. }
  40. }
  41. func (m *Manager) Initialize(defaultOutboundFallback func() (adapter.Outbound, error)) {
  42. m.defaultOutboundFallback = defaultOutboundFallback
  43. }
  44. func (m *Manager) Start(stage adapter.StartStage) error {
  45. m.access.Lock()
  46. if m.started && m.stage >= stage {
  47. panic("already started")
  48. }
  49. m.started = true
  50. m.stage = stage
  51. if stage == adapter.StartStateStart {
  52. if m.defaultTag != "" && m.defaultOutbound == nil {
  53. defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
  54. if !loaded {
  55. m.access.Unlock()
  56. return E.New("default outbound not found: ", m.defaultTag)
  57. }
  58. m.defaultOutbound = defaultEndpoint
  59. }
  60. if m.defaultOutbound == nil {
  61. directOutbound, err := m.defaultOutboundFallback()
  62. if err != nil {
  63. m.access.Unlock()
  64. return E.Cause(err, "create direct outbound for fallback")
  65. }
  66. m.outbounds = append(m.outbounds, directOutbound)
  67. m.outboundByTag[directOutbound.Tag()] = directOutbound
  68. m.defaultOutbound = directOutbound
  69. }
  70. outbounds := m.outbounds
  71. m.access.Unlock()
  72. return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
  73. } else {
  74. outbounds := m.outbounds
  75. m.access.Unlock()
  76. for _, outbound := range outbounds {
  77. err := adapter.LegacyStart(outbound, stage)
  78. if err != nil {
  79. return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  80. }
  81. }
  82. }
  83. return nil
  84. }
  85. func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
  86. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  87. started := make(map[string]bool)
  88. for {
  89. canContinue := false
  90. startOne:
  91. for _, outboundToStart := range outbounds {
  92. outboundTag := outboundToStart.Tag()
  93. if started[outboundTag] {
  94. continue
  95. }
  96. dependencies := outboundToStart.Dependencies()
  97. for _, dependency := range dependencies {
  98. if !started[dependency] {
  99. continue startOne
  100. }
  101. }
  102. started[outboundTag] = true
  103. canContinue = true
  104. if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
  105. monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  106. err := starter.Start(adapter.StartStateStart)
  107. monitor.Finish()
  108. if err != nil {
  109. return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  110. }
  111. } else if starter, isStarter := outboundToStart.(interface {
  112. Start() error
  113. }); isStarter {
  114. monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  115. err := starter.Start()
  116. monitor.Finish()
  117. if err != nil {
  118. return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  119. }
  120. }
  121. }
  122. if len(started) == len(outbounds) {
  123. break
  124. }
  125. if canContinue {
  126. continue
  127. }
  128. currentOutbound := common.Find(outbounds, func(it adapter.Outbound) bool {
  129. return !started[it.Tag()]
  130. })
  131. var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
  132. lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
  133. problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  134. return !started[it]
  135. })
  136. if common.Contains(oTree, problemOutboundTag) {
  137. return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
  138. }
  139. m.access.Lock()
  140. problemOutbound := m.outboundByTag[problemOutboundTag]
  141. m.access.Unlock()
  142. if problemOutbound == nil {
  143. return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
  144. }
  145. return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
  146. }
  147. return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
  148. }
  149. return nil
  150. }
  151. func (m *Manager) Close() error {
  152. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  153. m.access.Lock()
  154. if !m.started {
  155. m.access.Unlock()
  156. return nil
  157. }
  158. m.started = false
  159. outbounds := m.outbounds
  160. m.outbounds = nil
  161. m.access.Unlock()
  162. var err error
  163. for _, outbound := range outbounds {
  164. if closer, isCloser := outbound.(io.Closer); isCloser {
  165. monitor.Start("close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  166. err = E.Append(err, closer.Close(), func(err error) error {
  167. return E.Cause(err, "close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  168. })
  169. monitor.Finish()
  170. }
  171. }
  172. return nil
  173. }
  174. func (m *Manager) Outbounds() []adapter.Outbound {
  175. m.access.RLock()
  176. defer m.access.RUnlock()
  177. return m.outbounds
  178. }
  179. func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
  180. m.access.RLock()
  181. outbound, found := m.outboundByTag[tag]
  182. m.access.RUnlock()
  183. if found {
  184. return outbound, true
  185. }
  186. return m.endpoint.Get(tag)
  187. }
  188. func (m *Manager) Default() adapter.Outbound {
  189. m.access.RLock()
  190. defer m.access.RUnlock()
  191. return m.defaultOutbound
  192. }
  193. func (m *Manager) Remove(tag string) error {
  194. m.access.Lock()
  195. defer m.access.Unlock()
  196. outbound, found := m.outboundByTag[tag]
  197. if !found {
  198. return os.ErrInvalid
  199. }
  200. delete(m.outboundByTag, tag)
  201. index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  202. return it == outbound
  203. })
  204. if index == -1 {
  205. panic("invalid inbound index")
  206. }
  207. m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
  208. started := m.started
  209. if m.defaultOutbound == outbound {
  210. if len(m.outbounds) > 0 {
  211. m.defaultOutbound = m.outbounds[0]
  212. m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
  213. } else {
  214. m.defaultOutbound = nil
  215. }
  216. }
  217. dependBy := m.dependByTag[tag]
  218. if len(dependBy) > 0 {
  219. return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  220. }
  221. dependencies := outbound.Dependencies()
  222. for _, dependency := range dependencies {
  223. if len(m.dependByTag[dependency]) == 1 {
  224. delete(m.dependByTag, dependency)
  225. } else {
  226. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  227. return it != tag
  228. })
  229. }
  230. }
  231. if started {
  232. return common.Close(outbound)
  233. }
  234. return nil
  235. }
  236. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
  237. if tag == "" {
  238. return os.ErrInvalid
  239. }
  240. outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
  241. if err != nil {
  242. return err
  243. }
  244. if m.started {
  245. for _, stage := range adapter.ListStartStages {
  246. err = adapter.LegacyStart(outbound, stage)
  247. if err != nil {
  248. return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  249. }
  250. }
  251. }
  252. m.access.Lock()
  253. defer m.access.Unlock()
  254. if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
  255. if m.started {
  256. err = common.Close(existsOutbound)
  257. if err != nil {
  258. return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
  259. }
  260. }
  261. existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  262. return it == existsOutbound
  263. })
  264. if existsIndex == -1 {
  265. panic("invalid inbound index")
  266. }
  267. m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
  268. }
  269. m.outbounds = append(m.outbounds, outbound)
  270. m.outboundByTag[tag] = outbound
  271. dependencies := outbound.Dependencies()
  272. for _, dependency := range dependencies {
  273. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  274. }
  275. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
  276. m.defaultOutbound = outbound
  277. if m.started {
  278. m.logger.Info("updated default outbound to ", outbound.Tag())
  279. }
  280. }
  281. return nil
  282. }