manager.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package outbound
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. )
  16. var _ adapter.OutboundManager = (*Manager)(nil)
  17. type Manager struct {
  18. logger log.ContextLogger
  19. registry adapter.OutboundRegistry
  20. endpoint adapter.EndpointManager
  21. defaultTag string
  22. access sync.RWMutex
  23. started bool
  24. stage adapter.StartStage
  25. outbounds []adapter.Outbound
  26. outboundByTag map[string]adapter.Outbound
  27. dependByTag map[string][]string
  28. defaultOutbound adapter.Outbound
  29. defaultOutboundFallback adapter.Outbound
  30. }
  31. func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
  32. return &Manager{
  33. logger: logger,
  34. registry: registry,
  35. endpoint: endpoint,
  36. defaultTag: defaultTag,
  37. outboundByTag: make(map[string]adapter.Outbound),
  38. dependByTag: make(map[string][]string),
  39. }
  40. }
  41. func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) {
  42. m.defaultOutboundFallback = defaultOutboundFallback
  43. }
  44. func (m *Manager) Start(stage adapter.StartStage) error {
  45. m.access.Lock()
  46. if m.started && m.stage >= stage {
  47. panic("already started")
  48. }
  49. m.started = true
  50. m.stage = stage
  51. outbounds := m.outbounds
  52. m.access.Unlock()
  53. if stage == adapter.StartStateStart {
  54. if m.defaultTag != "" && m.defaultOutbound == nil {
  55. defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
  56. if !loaded {
  57. return E.New("default outbound not found: ", m.defaultTag)
  58. }
  59. m.defaultOutbound = defaultEndpoint
  60. }
  61. return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
  62. } else {
  63. for _, outbound := range outbounds {
  64. err := adapter.LegacyStart(outbound, stage)
  65. if err != nil {
  66. return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  67. }
  68. }
  69. }
  70. return nil
  71. }
  72. func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
  73. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  74. started := make(map[string]bool)
  75. for {
  76. canContinue := false
  77. startOne:
  78. for _, outboundToStart := range outbounds {
  79. outboundTag := outboundToStart.Tag()
  80. if started[outboundTag] {
  81. continue
  82. }
  83. dependencies := outboundToStart.Dependencies()
  84. for _, dependency := range dependencies {
  85. if !started[dependency] {
  86. continue startOne
  87. }
  88. }
  89. started[outboundTag] = true
  90. canContinue = true
  91. if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
  92. monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  93. err := starter.Start(adapter.StartStateStart)
  94. monitor.Finish()
  95. if err != nil {
  96. return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  97. }
  98. } else if starter, isStarter := outboundToStart.(interface {
  99. Start() error
  100. }); isStarter {
  101. monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  102. err := starter.Start()
  103. monitor.Finish()
  104. if err != nil {
  105. return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
  106. }
  107. }
  108. }
  109. if len(started) == len(outbounds) {
  110. break
  111. }
  112. if canContinue {
  113. continue
  114. }
  115. currentOutbound := common.Find(outbounds, func(it adapter.Outbound) bool {
  116. return !started[it.Tag()]
  117. })
  118. var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
  119. lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
  120. problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  121. return !started[it]
  122. })
  123. if common.Contains(oTree, problemOutboundTag) {
  124. return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
  125. }
  126. m.access.Lock()
  127. problemOutbound := m.outboundByTag[problemOutboundTag]
  128. m.access.Unlock()
  129. if problemOutbound == nil {
  130. return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
  131. }
  132. return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
  133. }
  134. return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
  135. }
  136. return nil
  137. }
  138. func (m *Manager) Close() error {
  139. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  140. m.access.Lock()
  141. if !m.started {
  142. m.access.Unlock()
  143. return nil
  144. }
  145. m.started = false
  146. outbounds := m.outbounds
  147. m.outbounds = nil
  148. m.access.Unlock()
  149. var err error
  150. for _, outbound := range outbounds {
  151. if closer, isCloser := outbound.(io.Closer); isCloser {
  152. monitor.Start("close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  153. err = E.Append(err, closer.Close(), func(err error) error {
  154. return E.Cause(err, "close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  155. })
  156. monitor.Finish()
  157. }
  158. }
  159. return nil
  160. }
  161. func (m *Manager) Outbounds() []adapter.Outbound {
  162. m.access.RLock()
  163. defer m.access.RUnlock()
  164. return m.outbounds
  165. }
  166. func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
  167. m.access.RLock()
  168. outbound, found := m.outboundByTag[tag]
  169. m.access.RUnlock()
  170. if found {
  171. return outbound, true
  172. }
  173. return m.endpoint.Get(tag)
  174. }
  175. func (m *Manager) Default() adapter.Outbound {
  176. m.access.RLock()
  177. defer m.access.RUnlock()
  178. if m.defaultOutbound != nil {
  179. return m.defaultOutbound
  180. } else {
  181. return m.defaultOutboundFallback
  182. }
  183. }
  184. func (m *Manager) Remove(tag string) error {
  185. m.access.Lock()
  186. defer m.access.Unlock()
  187. outbound, found := m.outboundByTag[tag]
  188. if !found {
  189. return os.ErrInvalid
  190. }
  191. delete(m.outboundByTag, tag)
  192. index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  193. return it == outbound
  194. })
  195. if index == -1 {
  196. panic("invalid inbound index")
  197. }
  198. m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
  199. started := m.started
  200. if m.defaultOutbound == outbound {
  201. if len(m.outbounds) > 0 {
  202. m.defaultOutbound = m.outbounds[0]
  203. m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
  204. } else {
  205. m.defaultOutbound = nil
  206. }
  207. }
  208. dependBy := m.dependByTag[tag]
  209. if len(dependBy) > 0 {
  210. return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  211. }
  212. dependencies := outbound.Dependencies()
  213. for _, dependency := range dependencies {
  214. if len(m.dependByTag[dependency]) == 1 {
  215. delete(m.dependByTag, dependency)
  216. } else {
  217. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  218. return it != tag
  219. })
  220. }
  221. }
  222. if started {
  223. return common.Close(outbound)
  224. }
  225. return nil
  226. }
  227. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
  228. if tag == "" {
  229. return os.ErrInvalid
  230. }
  231. outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
  232. if err != nil {
  233. return err
  234. }
  235. if m.started {
  236. for _, stage := range adapter.ListStartStages {
  237. err = adapter.LegacyStart(outbound, stage)
  238. if err != nil {
  239. return E.Cause(err, stage, " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
  240. }
  241. }
  242. }
  243. m.access.Lock()
  244. defer m.access.Unlock()
  245. if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
  246. if m.started {
  247. err = common.Close(existsOutbound)
  248. if err != nil {
  249. return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
  250. }
  251. }
  252. existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  253. return it == existsOutbound
  254. })
  255. if existsIndex == -1 {
  256. panic("invalid inbound index")
  257. }
  258. m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
  259. }
  260. m.outbounds = append(m.outbounds, outbound)
  261. m.outboundByTag[tag] = outbound
  262. dependencies := outbound.Dependencies()
  263. for _, dependency := range dependencies {
  264. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  265. }
  266. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
  267. m.defaultOutbound = outbound
  268. if m.started {
  269. m.logger.Info("updated default outbound to ", outbound.Tag())
  270. }
  271. }
  272. return nil
  273. }