manager.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. package outbound
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/sagernet/sing-box/adapter"
  10. "github.com/sagernet/sing-box/common/taskmonitor"
  11. C "github.com/sagernet/sing-box/constant"
  12. "github.com/sagernet/sing-box/log"
  13. "github.com/sagernet/sing/common"
  14. E "github.com/sagernet/sing/common/exceptions"
  15. F "github.com/sagernet/sing/common/format"
  16. "github.com/sagernet/sing/common/logger"
  17. )
  18. var _ adapter.OutboundManager = (*Manager)(nil)
  19. type Manager struct {
  20. logger log.ContextLogger
  21. registry adapter.OutboundRegistry
  22. endpoint adapter.EndpointManager
  23. defaultTag string
  24. access sync.RWMutex
  25. started bool
  26. stage adapter.StartStage
  27. outbounds []adapter.Outbound
  28. outboundByTag map[string]adapter.Outbound
  29. dependByTag map[string][]string
  30. defaultOutbound adapter.Outbound
  31. defaultOutboundFallback func() (adapter.Outbound, error)
  32. }
  33. func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
  34. return &Manager{
  35. logger: logger,
  36. registry: registry,
  37. endpoint: endpoint,
  38. defaultTag: defaultTag,
  39. outboundByTag: make(map[string]adapter.Outbound),
  40. dependByTag: make(map[string][]string),
  41. }
  42. }
  43. func (m *Manager) Initialize(defaultOutboundFallback func() (adapter.Outbound, error)) {
  44. m.defaultOutboundFallback = defaultOutboundFallback
  45. }
  46. func (m *Manager) Start(stage adapter.StartStage) error {
  47. m.access.Lock()
  48. if m.started && m.stage >= stage {
  49. panic("already started")
  50. }
  51. m.started = true
  52. m.stage = stage
  53. if stage == adapter.StartStateStart {
  54. if m.defaultTag != "" && m.defaultOutbound == nil {
  55. defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
  56. if !loaded {
  57. m.access.Unlock()
  58. return E.New("default outbound not found: ", m.defaultTag)
  59. }
  60. m.defaultOutbound = defaultEndpoint
  61. }
  62. if m.defaultOutbound == nil {
  63. directOutbound, err := m.defaultOutboundFallback()
  64. if err != nil {
  65. m.access.Unlock()
  66. return E.Cause(err, "create direct outbound for fallback")
  67. }
  68. m.outbounds = append(m.outbounds, directOutbound)
  69. m.outboundByTag[directOutbound.Tag()] = directOutbound
  70. m.defaultOutbound = directOutbound
  71. }
  72. outbounds := m.outbounds
  73. m.access.Unlock()
  74. return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
  75. } else {
  76. outbounds := m.outbounds
  77. m.access.Unlock()
  78. for _, outbound := range outbounds {
  79. name := "outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  80. m.logger.Trace(stage, " ", name)
  81. startTime := time.Now()
  82. err := adapter.LegacyStart(outbound, stage)
  83. if err != nil {
  84. return E.Cause(err, stage, " ", name)
  85. }
  86. m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  87. }
  88. }
  89. return nil
  90. }
  91. func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
  92. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  93. started := make(map[string]bool)
  94. for {
  95. canContinue := false
  96. startOne:
  97. for _, outboundToStart := range outbounds {
  98. outboundTag := outboundToStart.Tag()
  99. if started[outboundTag] {
  100. continue
  101. }
  102. dependencies := outboundToStart.Dependencies()
  103. for _, dependency := range dependencies {
  104. if !started[dependency] {
  105. continue startOne
  106. }
  107. }
  108. started[outboundTag] = true
  109. canContinue = true
  110. name := "outbound/" + outboundToStart.Type() + "[" + outboundTag + "]"
  111. if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
  112. m.logger.Trace("start ", name)
  113. startTime := time.Now()
  114. monitor.Start("start ", name)
  115. err := starter.Start(adapter.StartStateStart)
  116. monitor.Finish()
  117. if err != nil {
  118. return E.Cause(err, "start ", name)
  119. }
  120. m.logger.Trace("start ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  121. } else if starter, isStarter := outboundToStart.(interface {
  122. Start() error
  123. }); isStarter {
  124. m.logger.Trace("start ", name)
  125. startTime := time.Now()
  126. monitor.Start("start ", name)
  127. err := starter.Start()
  128. monitor.Finish()
  129. if err != nil {
  130. return E.Cause(err, "start ", name)
  131. }
  132. m.logger.Trace("start ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  133. }
  134. }
  135. if len(started) == len(outbounds) {
  136. break
  137. }
  138. if canContinue {
  139. continue
  140. }
  141. currentOutbound := common.Find(outbounds, func(it adapter.Outbound) bool {
  142. return !started[it.Tag()]
  143. })
  144. var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
  145. lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
  146. problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  147. return !started[it]
  148. })
  149. if common.Contains(oTree, problemOutboundTag) {
  150. return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
  151. }
  152. m.access.Lock()
  153. problemOutbound := m.outboundByTag[problemOutboundTag]
  154. m.access.Unlock()
  155. if problemOutbound == nil {
  156. return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
  157. }
  158. return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
  159. }
  160. return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
  161. }
  162. return nil
  163. }
  164. func (m *Manager) Close() error {
  165. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  166. m.access.Lock()
  167. if !m.started {
  168. m.access.Unlock()
  169. return nil
  170. }
  171. m.started = false
  172. outbounds := m.outbounds
  173. m.outbounds = nil
  174. m.access.Unlock()
  175. var err error
  176. for _, outbound := range outbounds {
  177. if closer, isCloser := outbound.(io.Closer); isCloser {
  178. name := "outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  179. m.logger.Trace("close ", name)
  180. startTime := time.Now()
  181. monitor.Start("close ", name)
  182. err = E.Append(err, closer.Close(), func(err error) error {
  183. return E.Cause(err, "close ", name)
  184. })
  185. monitor.Finish()
  186. m.logger.Trace("close ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  187. }
  188. }
  189. return nil
  190. }
  191. func (m *Manager) Outbounds() []adapter.Outbound {
  192. m.access.RLock()
  193. defer m.access.RUnlock()
  194. return m.outbounds
  195. }
  196. func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
  197. m.access.RLock()
  198. outbound, found := m.outboundByTag[tag]
  199. m.access.RUnlock()
  200. if found {
  201. return outbound, true
  202. }
  203. return m.endpoint.Get(tag)
  204. }
  205. func (m *Manager) Default() adapter.Outbound {
  206. m.access.RLock()
  207. defer m.access.RUnlock()
  208. return m.defaultOutbound
  209. }
  210. func (m *Manager) Remove(tag string) error {
  211. m.access.Lock()
  212. defer m.access.Unlock()
  213. outbound, found := m.outboundByTag[tag]
  214. if !found {
  215. return os.ErrInvalid
  216. }
  217. delete(m.outboundByTag, tag)
  218. index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  219. return it == outbound
  220. })
  221. if index == -1 {
  222. panic("invalid inbound index")
  223. }
  224. m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
  225. started := m.started
  226. if m.defaultOutbound == outbound {
  227. if len(m.outbounds) > 0 {
  228. m.defaultOutbound = m.outbounds[0]
  229. m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
  230. } else {
  231. m.defaultOutbound = nil
  232. }
  233. }
  234. dependBy := m.dependByTag[tag]
  235. if len(dependBy) > 0 {
  236. return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  237. }
  238. dependencies := outbound.Dependencies()
  239. for _, dependency := range dependencies {
  240. if len(m.dependByTag[dependency]) == 1 {
  241. delete(m.dependByTag, dependency)
  242. } else {
  243. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  244. return it != tag
  245. })
  246. }
  247. }
  248. if started {
  249. return common.Close(outbound)
  250. }
  251. return nil
  252. }
  253. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
  254. if tag == "" {
  255. return os.ErrInvalid
  256. }
  257. outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
  258. if err != nil {
  259. return err
  260. }
  261. if m.started {
  262. name := "outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  263. for _, stage := range adapter.ListStartStages {
  264. m.logger.Trace(stage, " ", name)
  265. startTime := time.Now()
  266. err = adapter.LegacyStart(outbound, stage)
  267. if err != nil {
  268. return E.Cause(err, stage, " ", name)
  269. }
  270. m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
  271. }
  272. }
  273. m.access.Lock()
  274. defer m.access.Unlock()
  275. if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
  276. if m.started {
  277. err = common.Close(existsOutbound)
  278. if err != nil {
  279. return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
  280. }
  281. }
  282. existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  283. return it == existsOutbound
  284. })
  285. if existsIndex == -1 {
  286. panic("invalid inbound index")
  287. }
  288. m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
  289. }
  290. m.outbounds = append(m.outbounds, outbound)
  291. m.outboundByTag[tag] = outbound
  292. dependencies := outbound.Dependencies()
  293. for _, dependency := range dependencies {
  294. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  295. }
  296. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
  297. m.defaultOutbound = outbound
  298. if m.started {
  299. m.logger.Info("updated default outbound to ", outbound.Tag())
  300. }
  301. }
  302. return nil
  303. }