manager.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package outbound
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. )
  16. var _ adapter.OutboundManager = (*Manager)(nil)
  17. type Manager struct {
  18. logger log.ContextLogger
  19. registry adapter.OutboundRegistry
  20. endpoint adapter.EndpointManager
  21. defaultTag string
  22. access sync.RWMutex
  23. started bool
  24. stage adapter.StartStage
  25. outbounds []adapter.Outbound
  26. outboundByTag map[string]adapter.Outbound
  27. dependByTag map[string][]string
  28. defaultOutbound adapter.Outbound
  29. defaultOutboundFallback func() (adapter.Outbound, error)
  30. }
  31. func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
  32. return &Manager{
  33. logger: logger,
  34. registry: registry,
  35. endpoint: endpoint,
  36. defaultTag: defaultTag,
  37. outboundByTag: make(map[string]adapter.Outbound),
  38. dependByTag: make(map[string][]string),
  39. }
  40. }
  41. func (m *Manager) Initialize(defaultOutboundFallback func() (adapter.Outbound, error)) {
  42. m.defaultOutboundFallback = defaultOutboundFallback
  43. }
  44. func (m *Manager) Start(stage adapter.StartStage) error {
  45. m.access.Lock()
  46. if m.started && m.stage >= stage {
  47. panic("already started")
  48. }
  49. m.started = true
  50. m.stage = stage
  51. if stage == adapter.StartStateStart {
  52. if m.defaultTag != "" && m.defaultOutbound == nil {
  53. defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
  54. if !loaded {
  55. m.access.Unlock()
  56. return E.New("default outbound not found: ", m.defaultTag)
  57. }
  58. m.defaultOutbound = defaultEndpoint
  59. }
  60. if m.defaultOutbound == nil {
  61. directOutbound, err := m.defaultOutboundFallback()
  62. if err != nil {
  63. m.access.Unlock()
  64. return E.Cause(err, "create direct outbound for fallback")
  65. }
  66. m.outbounds = append(m.outbounds, directOutbound)
  67. m.outboundByTag[directOutbound.Tag()] = directOutbound
  68. m.defaultOutbound = directOutbound
  69. }
  70. outbounds := m.outbounds
  71. m.access.Unlock()
  72. return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
  73. } else {
  74. outbounds := m.outbounds
  75. m.access.Unlock()
  76. for _, outbound := range outbounds {
  77. name := "outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  78. done := adapter.LogElapsed(m.logger, stage, " ", name)
  79. err := adapter.LegacyStart(outbound, stage)
  80. done()
  81. if err != nil {
  82. return E.Cause(err, stage, " ", name)
  83. }
  84. }
  85. }
  86. return nil
  87. }
  88. func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error {
  89. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  90. started := make(map[string]bool)
  91. for {
  92. canContinue := false
  93. startOne:
  94. for _, outboundToStart := range outbounds {
  95. outboundTag := outboundToStart.Tag()
  96. if started[outboundTag] {
  97. continue
  98. }
  99. dependencies := outboundToStart.Dependencies()
  100. for _, dependency := range dependencies {
  101. if !started[dependency] {
  102. continue startOne
  103. }
  104. }
  105. started[outboundTag] = true
  106. canContinue = true
  107. name := "outbound/" + outboundToStart.Type() + "[" + outboundTag + "]"
  108. if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter {
  109. done := adapter.LogElapsed(m.logger, "start ", name)
  110. monitor.Start("start ", name)
  111. err := starter.Start(adapter.StartStateStart)
  112. monitor.Finish()
  113. done()
  114. if err != nil {
  115. return E.Cause(err, "start ", name)
  116. }
  117. } else if starter, isStarter := outboundToStart.(interface {
  118. Start() error
  119. }); isStarter {
  120. done := adapter.LogElapsed(m.logger, "start ", name)
  121. monitor.Start("start ", name)
  122. err := starter.Start()
  123. monitor.Finish()
  124. done()
  125. if err != nil {
  126. return E.Cause(err, "start ", name)
  127. }
  128. }
  129. }
  130. if len(started) == len(outbounds) {
  131. break
  132. }
  133. if canContinue {
  134. continue
  135. }
  136. currentOutbound := common.Find(outbounds, func(it adapter.Outbound) bool {
  137. return !started[it.Tag()]
  138. })
  139. var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
  140. lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
  141. problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  142. return !started[it]
  143. })
  144. if common.Contains(oTree, problemOutboundTag) {
  145. return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
  146. }
  147. m.access.Lock()
  148. problemOutbound := m.outboundByTag[problemOutboundTag]
  149. m.access.Unlock()
  150. if problemOutbound == nil {
  151. return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
  152. }
  153. return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
  154. }
  155. return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
  156. }
  157. return nil
  158. }
  159. func (m *Manager) Close() error {
  160. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  161. m.access.Lock()
  162. if !m.started {
  163. m.access.Unlock()
  164. return nil
  165. }
  166. m.started = false
  167. outbounds := m.outbounds
  168. m.outbounds = nil
  169. m.access.Unlock()
  170. var err error
  171. for _, outbound := range outbounds {
  172. if closer, isCloser := outbound.(io.Closer); isCloser {
  173. name := "outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  174. done := adapter.LogElapsed(m.logger, "close ", name)
  175. monitor.Start("close ", name)
  176. err = E.Append(err, closer.Close(), func(err error) error {
  177. return E.Cause(err, "close ", name)
  178. })
  179. monitor.Finish()
  180. done()
  181. }
  182. }
  183. return nil
  184. }
  185. func (m *Manager) Outbounds() []adapter.Outbound {
  186. m.access.RLock()
  187. defer m.access.RUnlock()
  188. return m.outbounds
  189. }
  190. func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
  191. m.access.RLock()
  192. outbound, found := m.outboundByTag[tag]
  193. m.access.RUnlock()
  194. if found {
  195. return outbound, true
  196. }
  197. return m.endpoint.Get(tag)
  198. }
  199. func (m *Manager) Default() adapter.Outbound {
  200. m.access.RLock()
  201. defer m.access.RUnlock()
  202. return m.defaultOutbound
  203. }
  204. func (m *Manager) Remove(tag string) error {
  205. m.access.Lock()
  206. defer m.access.Unlock()
  207. outbound, found := m.outboundByTag[tag]
  208. if !found {
  209. return os.ErrInvalid
  210. }
  211. delete(m.outboundByTag, tag)
  212. index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  213. return it == outbound
  214. })
  215. if index == -1 {
  216. panic("invalid inbound index")
  217. }
  218. m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
  219. started := m.started
  220. if m.defaultOutbound == outbound {
  221. if len(m.outbounds) > 0 {
  222. m.defaultOutbound = m.outbounds[0]
  223. m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
  224. } else {
  225. m.defaultOutbound = nil
  226. }
  227. }
  228. dependBy := m.dependByTag[tag]
  229. if len(dependBy) > 0 {
  230. return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  231. }
  232. dependencies := outbound.Dependencies()
  233. for _, dependency := range dependencies {
  234. if len(m.dependByTag[dependency]) == 1 {
  235. delete(m.dependByTag, dependency)
  236. } else {
  237. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  238. return it != tag
  239. })
  240. }
  241. }
  242. if started {
  243. return common.Close(outbound)
  244. }
  245. return nil
  246. }
  247. func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
  248. if tag == "" {
  249. return os.ErrInvalid
  250. }
  251. outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
  252. if err != nil {
  253. return err
  254. }
  255. if m.started {
  256. name := "outbound/" + outbound.Type() + "[" + outbound.Tag() + "]"
  257. for _, stage := range adapter.ListStartStages {
  258. done := adapter.LogElapsed(m.logger, stage, " ", name)
  259. err = adapter.LegacyStart(outbound, stage)
  260. done()
  261. if err != nil {
  262. return E.Cause(err, stage, " ", name)
  263. }
  264. }
  265. }
  266. m.access.Lock()
  267. defer m.access.Unlock()
  268. if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
  269. if m.started {
  270. err = common.Close(existsOutbound)
  271. if err != nil {
  272. return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
  273. }
  274. }
  275. existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
  276. return it == existsOutbound
  277. })
  278. if existsIndex == -1 {
  279. panic("invalid inbound index")
  280. }
  281. m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
  282. }
  283. m.outbounds = append(m.outbounds, outbound)
  284. m.outboundByTag[tag] = outbound
  285. dependencies := outbound.Dependencies()
  286. for _, dependency := range dependencies {
  287. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  288. }
  289. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
  290. m.defaultOutbound = outbound
  291. if m.started {
  292. m.logger.Info("updated default outbound to ", outbound.Tag())
  293. }
  294. }
  295. return nil
  296. }