transport_manager.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package dns
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. )
  16. var _ adapter.DNSTransportManager = (*TransportManager)(nil)
  17. type TransportManager struct {
  18. logger log.ContextLogger
  19. registry adapter.DNSTransportRegistry
  20. outbound adapter.OutboundManager
  21. defaultTag string
  22. access sync.RWMutex
  23. started bool
  24. stage adapter.StartStage
  25. transports []adapter.DNSTransport
  26. transportByTag map[string]adapter.DNSTransport
  27. dependByTag map[string][]string
  28. defaultTransport adapter.DNSTransport
  29. defaultTransportFallback func() (adapter.DNSTransport, error)
  30. fakeIPTransport adapter.FakeIPTransport
  31. }
  32. func NewTransportManager(logger logger.ContextLogger, registry adapter.DNSTransportRegistry, outbound adapter.OutboundManager, defaultTag string) *TransportManager {
  33. return &TransportManager{
  34. logger: logger,
  35. registry: registry,
  36. outbound: outbound,
  37. defaultTag: defaultTag,
  38. transportByTag: make(map[string]adapter.DNSTransport),
  39. dependByTag: make(map[string][]string),
  40. }
  41. }
  42. func (m *TransportManager) Initialize(defaultTransportFallback func() (adapter.DNSTransport, error)) {
  43. m.defaultTransportFallback = defaultTransportFallback
  44. }
  45. func (m *TransportManager) Start(stage adapter.StartStage) error {
  46. m.access.Lock()
  47. if m.started && m.stage >= stage {
  48. panic("already started")
  49. }
  50. m.started = true
  51. m.stage = stage
  52. if stage == adapter.StartStateStart {
  53. if m.defaultTag != "" && m.defaultTransport == nil {
  54. m.access.Unlock()
  55. return E.New("default DNS server not found: ", m.defaultTag)
  56. }
  57. if m.defaultTransport == nil {
  58. defaultTransport, err := m.defaultTransportFallback()
  59. if err != nil {
  60. m.access.Unlock()
  61. return E.Cause(err, "default DNS server fallback")
  62. }
  63. m.transports = append(m.transports, defaultTransport)
  64. m.transportByTag[defaultTransport.Tag()] = defaultTransport
  65. m.defaultTransport = defaultTransport
  66. }
  67. transports := m.transports
  68. m.access.Unlock()
  69. return m.startTransports(transports)
  70. } else {
  71. transports := m.transports
  72. m.access.Unlock()
  73. for _, outbound := range transports {
  74. err := adapter.LegacyStart(outbound, stage)
  75. if err != nil {
  76. return E.Cause(err, stage, " dns/", outbound.Type(), "[", outbound.Tag(), "]")
  77. }
  78. }
  79. }
  80. return nil
  81. }
  82. func (m *TransportManager) startTransports(transports []adapter.DNSTransport) error {
  83. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  84. started := make(map[string]bool)
  85. for {
  86. canContinue := false
  87. startOne:
  88. for _, transportToStart := range transports {
  89. transportTag := transportToStart.Tag()
  90. if started[transportTag] {
  91. continue
  92. }
  93. dependencies := transportToStart.Dependencies()
  94. for _, dependency := range dependencies {
  95. if !started[dependency] {
  96. continue startOne
  97. }
  98. }
  99. started[transportTag] = true
  100. canContinue = true
  101. if starter, isStarter := transportToStart.(adapter.Lifecycle); isStarter {
  102. monitor.Start("start dns/", transportToStart.Type(), "[", transportTag, "]")
  103. err := starter.Start(adapter.StartStateStart)
  104. monitor.Finish()
  105. if err != nil {
  106. return E.Cause(err, "start dns/", transportToStart.Type(), "[", transportTag, "]")
  107. }
  108. }
  109. }
  110. if len(started) == len(transports) {
  111. break
  112. }
  113. if canContinue {
  114. continue
  115. }
  116. currentTransport := common.Find(transports, func(it adapter.DNSTransport) bool {
  117. return !started[it.Tag()]
  118. })
  119. var lintTransport func(oTree []string, oCurrent adapter.DNSTransport) error
  120. lintTransport = func(oTree []string, oCurrent adapter.DNSTransport) error {
  121. problemTransportTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  122. return !started[it]
  123. })
  124. if common.Contains(oTree, problemTransportTag) {
  125. return E.New("circular server dependency: ", strings.Join(oTree, " -> "), " -> ", problemTransportTag)
  126. }
  127. m.access.Lock()
  128. problemTransport := m.transportByTag[problemTransportTag]
  129. m.access.Unlock()
  130. if problemTransport == nil {
  131. return E.New("dependency[", problemTransportTag, "] not found for server[", oCurrent.Tag(), "]")
  132. }
  133. return lintTransport(append(oTree, problemTransportTag), problemTransport)
  134. }
  135. return lintTransport([]string{currentTransport.Tag()}, currentTransport)
  136. }
  137. return nil
  138. }
  139. func (m *TransportManager) Close() error {
  140. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  141. m.access.Lock()
  142. if !m.started {
  143. m.access.Unlock()
  144. return nil
  145. }
  146. m.started = false
  147. transports := m.transports
  148. m.transports = nil
  149. m.access.Unlock()
  150. var err error
  151. for _, transport := range transports {
  152. if closer, isCloser := transport.(io.Closer); isCloser {
  153. monitor.Start("close server/", transport.Type(), "[", transport.Tag(), "]")
  154. err = E.Append(err, closer.Close(), func(err error) error {
  155. return E.Cause(err, "close server/", transport.Type(), "[", transport.Tag(), "]")
  156. })
  157. monitor.Finish()
  158. }
  159. }
  160. return nil
  161. }
  162. func (m *TransportManager) Transports() []adapter.DNSTransport {
  163. m.access.RLock()
  164. defer m.access.RUnlock()
  165. return m.transports
  166. }
  167. func (m *TransportManager) Transport(tag string) (adapter.DNSTransport, bool) {
  168. m.access.RLock()
  169. outbound, found := m.transportByTag[tag]
  170. m.access.RUnlock()
  171. return outbound, found
  172. }
  173. func (m *TransportManager) Default() adapter.DNSTransport {
  174. m.access.RLock()
  175. defer m.access.RUnlock()
  176. return m.defaultTransport
  177. }
  178. func (m *TransportManager) FakeIP() adapter.FakeIPTransport {
  179. m.access.RLock()
  180. defer m.access.RUnlock()
  181. return m.fakeIPTransport
  182. }
  183. func (m *TransportManager) Remove(tag string) error {
  184. m.access.Lock()
  185. defer m.access.Unlock()
  186. transport, found := m.transportByTag[tag]
  187. if !found {
  188. return os.ErrInvalid
  189. }
  190. delete(m.transportByTag, tag)
  191. index := common.Index(m.transports, func(it adapter.DNSTransport) bool {
  192. return it == transport
  193. })
  194. if index == -1 {
  195. panic("invalid inbound index")
  196. }
  197. m.transports = append(m.transports[:index], m.transports[index+1:]...)
  198. started := m.started
  199. if m.defaultTransport == transport {
  200. if len(m.transports) > 0 {
  201. nextTransport := m.transports[0]
  202. if nextTransport.Type() != C.DNSTypeFakeIP {
  203. return E.New("default server cannot be fakeip")
  204. }
  205. m.defaultTransport = nextTransport
  206. m.logger.Info("updated default server to ", m.defaultTransport.Tag())
  207. } else {
  208. m.defaultTransport = nil
  209. }
  210. }
  211. dependBy := m.dependByTag[tag]
  212. if len(dependBy) > 0 {
  213. return E.New("server[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  214. }
  215. dependencies := transport.Dependencies()
  216. for _, dependency := range dependencies {
  217. if len(m.dependByTag[dependency]) == 1 {
  218. delete(m.dependByTag, dependency)
  219. } else {
  220. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  221. return it != tag
  222. })
  223. }
  224. }
  225. if started {
  226. transport.Close()
  227. }
  228. return nil
  229. }
  230. func (m *TransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) error {
  231. if tag == "" {
  232. return os.ErrInvalid
  233. }
  234. transport, err := m.registry.CreateDNSTransport(ctx, logger, tag, transportType, options)
  235. if err != nil {
  236. return err
  237. }
  238. m.access.Lock()
  239. defer m.access.Unlock()
  240. if m.started {
  241. for _, stage := range adapter.ListStartStages {
  242. err = adapter.LegacyStart(transport, stage)
  243. if err != nil {
  244. return E.Cause(err, stage, " dns/", transport.Type(), "[", transport.Tag(), "]")
  245. }
  246. }
  247. }
  248. if existsTransport, loaded := m.transportByTag[tag]; loaded {
  249. if m.started {
  250. err = common.Close(existsTransport)
  251. if err != nil {
  252. return E.Cause(err, "close dns/", existsTransport.Type(), "[", existsTransport.Tag(), "]")
  253. }
  254. }
  255. existsIndex := common.Index(m.transports, func(it adapter.DNSTransport) bool {
  256. return it == existsTransport
  257. })
  258. if existsIndex == -1 {
  259. panic("invalid inbound index")
  260. }
  261. m.transports = append(m.transports[:existsIndex], m.transports[existsIndex+1:]...)
  262. }
  263. m.transports = append(m.transports, transport)
  264. m.transportByTag[tag] = transport
  265. dependencies := transport.Dependencies()
  266. for _, dependency := range dependencies {
  267. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  268. }
  269. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultTransport == nil) {
  270. if transport.Type() == C.DNSTypeFakeIP {
  271. return E.New("default server cannot be fakeip")
  272. }
  273. m.defaultTransport = transport
  274. if m.started {
  275. m.logger.Info("updated default server to ", transport.Tag())
  276. }
  277. }
  278. if transport.Type() == C.DNSTypeFakeIP {
  279. if m.fakeIPTransport != nil {
  280. return E.New("multiple fakeip server are not supported")
  281. }
  282. m.fakeIPTransport = transport.(adapter.FakeIPTransport)
  283. }
  284. return nil
  285. }