transport_manager.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. package dns
  2. import (
  3. "context"
  4. "io"
  5. "os"
  6. "strings"
  7. "sync"
  8. "github.com/sagernet/sing-box/adapter"
  9. "github.com/sagernet/sing-box/common/taskmonitor"
  10. C "github.com/sagernet/sing-box/constant"
  11. "github.com/sagernet/sing-box/log"
  12. "github.com/sagernet/sing/common"
  13. E "github.com/sagernet/sing/common/exceptions"
  14. "github.com/sagernet/sing/common/logger"
  15. )
  16. var _ adapter.DNSTransportManager = (*TransportManager)(nil)
  17. type TransportManager struct {
  18. logger log.ContextLogger
  19. registry adapter.DNSTransportRegistry
  20. outbound adapter.OutboundManager
  21. defaultTag string
  22. access sync.RWMutex
  23. started bool
  24. stage adapter.StartStage
  25. transports []adapter.DNSTransport
  26. transportByTag map[string]adapter.DNSTransport
  27. dependByTag map[string][]string
  28. defaultTransport adapter.DNSTransport
  29. defaultTransportFallback adapter.DNSTransport
  30. fakeIPTransport adapter.FakeIPTransport
  31. }
  32. func NewTransportManager(logger logger.ContextLogger, registry adapter.DNSTransportRegistry, outbound adapter.OutboundManager, defaultTag string) *TransportManager {
  33. return &TransportManager{
  34. logger: logger,
  35. registry: registry,
  36. outbound: outbound,
  37. defaultTag: defaultTag,
  38. transportByTag: make(map[string]adapter.DNSTransport),
  39. dependByTag: make(map[string][]string),
  40. }
  41. }
  42. func (m *TransportManager) Initialize(defaultTransportFallback adapter.DNSTransport) {
  43. m.defaultTransportFallback = defaultTransportFallback
  44. }
  45. func (m *TransportManager) Start(stage adapter.StartStage) error {
  46. m.access.Lock()
  47. if m.started && m.stage >= stage {
  48. panic("already started")
  49. }
  50. m.started = true
  51. m.stage = stage
  52. transports := m.transports
  53. m.access.Unlock()
  54. if stage == adapter.StartStateStart {
  55. if m.defaultTag != "" && m.defaultTransport == nil {
  56. return E.New("default DNS server not found: ", m.defaultTag)
  57. }
  58. return m.startTransports(m.transports)
  59. } else {
  60. for _, outbound := range transports {
  61. err := adapter.LegacyStart(outbound, stage)
  62. if err != nil {
  63. return E.Cause(err, stage, " dns/", outbound.Type(), "[", outbound.Tag(), "]")
  64. }
  65. }
  66. }
  67. return nil
  68. }
  69. func (m *TransportManager) startTransports(transports []adapter.DNSTransport) error {
  70. monitor := taskmonitor.New(m.logger, C.StartTimeout)
  71. started := make(map[string]bool)
  72. for {
  73. canContinue := false
  74. startOne:
  75. for _, transportToStart := range transports {
  76. transportTag := transportToStart.Tag()
  77. if started[transportTag] {
  78. continue
  79. }
  80. dependencies := transportToStart.Dependencies()
  81. for _, dependency := range dependencies {
  82. if !started[dependency] {
  83. continue startOne
  84. }
  85. }
  86. started[transportTag] = true
  87. canContinue = true
  88. if starter, isStarter := transportToStart.(adapter.Lifecycle); isStarter {
  89. monitor.Start("start dns/", transportToStart.Type(), "[", transportTag, "]")
  90. err := starter.Start(adapter.StartStateStart)
  91. monitor.Finish()
  92. if err != nil {
  93. return E.Cause(err, "start dns/", transportToStart.Type(), "[", transportTag, "]")
  94. }
  95. }
  96. }
  97. if len(started) == len(transports) {
  98. break
  99. }
  100. if canContinue {
  101. continue
  102. }
  103. currentTransport := common.Find(transports, func(it adapter.DNSTransport) bool {
  104. return !started[it.Tag()]
  105. })
  106. var lintTransport func(oTree []string, oCurrent adapter.DNSTransport) error
  107. lintTransport = func(oTree []string, oCurrent adapter.DNSTransport) error {
  108. problemTransportTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
  109. return !started[it]
  110. })
  111. if common.Contains(oTree, problemTransportTag) {
  112. return E.New("circular server dependency: ", strings.Join(oTree, " -> "), " -> ", problemTransportTag)
  113. }
  114. m.access.Lock()
  115. problemTransport := m.transportByTag[problemTransportTag]
  116. m.access.Unlock()
  117. if problemTransport == nil {
  118. return E.New("dependency[", problemTransportTag, "] not found for server[", oCurrent.Tag(), "]")
  119. }
  120. return lintTransport(append(oTree, problemTransportTag), problemTransport)
  121. }
  122. return lintTransport([]string{currentTransport.Tag()}, currentTransport)
  123. }
  124. return nil
  125. }
  126. func (m *TransportManager) Close() error {
  127. monitor := taskmonitor.New(m.logger, C.StopTimeout)
  128. m.access.Lock()
  129. if !m.started {
  130. m.access.Unlock()
  131. return nil
  132. }
  133. m.started = false
  134. transports := m.transports
  135. m.transports = nil
  136. m.access.Unlock()
  137. var err error
  138. for _, transport := range transports {
  139. if closer, isCloser := transport.(io.Closer); isCloser {
  140. monitor.Start("close server/", transport.Type(), "[", transport.Tag(), "]")
  141. err = E.Append(err, closer.Close(), func(err error) error {
  142. return E.Cause(err, "close server/", transport.Type(), "[", transport.Tag(), "]")
  143. })
  144. monitor.Finish()
  145. }
  146. }
  147. return nil
  148. }
  149. func (m *TransportManager) Transports() []adapter.DNSTransport {
  150. m.access.RLock()
  151. defer m.access.RUnlock()
  152. return m.transports
  153. }
  154. func (m *TransportManager) Transport(tag string) (adapter.DNSTransport, bool) {
  155. m.access.RLock()
  156. outbound, found := m.transportByTag[tag]
  157. m.access.RUnlock()
  158. return outbound, found
  159. }
  160. func (m *TransportManager) Default() adapter.DNSTransport {
  161. m.access.RLock()
  162. defer m.access.RUnlock()
  163. if m.defaultTransport != nil {
  164. return m.defaultTransport
  165. } else {
  166. return m.defaultTransportFallback
  167. }
  168. }
  169. func (m *TransportManager) FakeIP() adapter.FakeIPTransport {
  170. m.access.RLock()
  171. defer m.access.RUnlock()
  172. return m.fakeIPTransport
  173. }
  174. func (m *TransportManager) Remove(tag string) error {
  175. m.access.Lock()
  176. defer m.access.Unlock()
  177. transport, found := m.transportByTag[tag]
  178. if !found {
  179. return os.ErrInvalid
  180. }
  181. delete(m.transportByTag, tag)
  182. index := common.Index(m.transports, func(it adapter.DNSTransport) bool {
  183. return it == transport
  184. })
  185. if index == -1 {
  186. panic("invalid inbound index")
  187. }
  188. m.transports = append(m.transports[:index], m.transports[index+1:]...)
  189. started := m.started
  190. if m.defaultTransport == transport {
  191. if len(m.transports) > 0 {
  192. nextTransport := m.transports[0]
  193. if nextTransport.Type() != C.DNSTypeFakeIP {
  194. return E.New("default server cannot be fakeip")
  195. }
  196. m.defaultTransport = nextTransport
  197. m.logger.Info("updated default server to ", m.defaultTransport.Tag())
  198. } else {
  199. m.defaultTransport = nil
  200. }
  201. }
  202. dependBy := m.dependByTag[tag]
  203. if len(dependBy) > 0 {
  204. return E.New("server[", tag, "] is depended by ", strings.Join(dependBy, ", "))
  205. }
  206. dependencies := transport.Dependencies()
  207. for _, dependency := range dependencies {
  208. if len(m.dependByTag[dependency]) == 1 {
  209. delete(m.dependByTag, dependency)
  210. } else {
  211. m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
  212. return it != tag
  213. })
  214. }
  215. }
  216. if started {
  217. transport.Close()
  218. }
  219. return nil
  220. }
  221. func (m *TransportManager) Create(ctx context.Context, logger log.ContextLogger, tag string, transportType string, options any) error {
  222. if tag == "" {
  223. return os.ErrInvalid
  224. }
  225. transport, err := m.registry.CreateDNSTransport(ctx, logger, tag, transportType, options)
  226. if err != nil {
  227. return err
  228. }
  229. m.access.Lock()
  230. defer m.access.Unlock()
  231. if m.started {
  232. for _, stage := range adapter.ListStartStages {
  233. err = adapter.LegacyStart(transport, stage)
  234. if err != nil {
  235. return E.Cause(err, stage, " dns/", transport.Type(), "[", transport.Tag(), "]")
  236. }
  237. }
  238. }
  239. if existsTransport, loaded := m.transportByTag[tag]; loaded {
  240. if m.started {
  241. err = common.Close(existsTransport)
  242. if err != nil {
  243. return E.Cause(err, "close dns/", existsTransport.Type(), "[", existsTransport.Tag(), "]")
  244. }
  245. }
  246. existsIndex := common.Index(m.transports, func(it adapter.DNSTransport) bool {
  247. return it == existsTransport
  248. })
  249. if existsIndex == -1 {
  250. panic("invalid inbound index")
  251. }
  252. m.transports = append(m.transports[:existsIndex], m.transports[existsIndex+1:]...)
  253. }
  254. m.transports = append(m.transports, transport)
  255. m.transportByTag[tag] = transport
  256. dependencies := transport.Dependencies()
  257. for _, dependency := range dependencies {
  258. m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
  259. }
  260. if tag == m.defaultTag || (m.defaultTag == "" && m.defaultTransport == nil) {
  261. if transport.Type() == C.DNSTypeFakeIP {
  262. return E.New("default server cannot be fakeip")
  263. }
  264. m.defaultTransport = transport
  265. if m.started {
  266. m.logger.Info("updated default server to ", transport.Tag())
  267. }
  268. }
  269. if transport.Type() == C.DNSTypeFakeIP {
  270. if m.fakeIPTransport != nil {
  271. return E.New("multiple fakeip server are not supported")
  272. }
  273. m.fakeIPTransport = transport.(adapter.FakeIPTransport)
  274. }
  275. return nil
  276. }