dependencies.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. /*
  2. Copyright 2020 Docker Compose CLI authors
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package compose
  14. import (
  15. "context"
  16. "fmt"
  17. "strings"
  18. "sync"
  19. "github.com/compose-spec/compose-go/types"
  20. "golang.org/x/sync/errgroup"
  21. "github.com/docker/compose/v2/pkg/utils"
  22. )
// ServiceStatus indicates the status of a service, as tracked on each Vertex
// of a dependency Graph during a traversal.
type ServiceStatus int

// Services status flags
const (
	// ServiceStopped is the initial status of every vertex when walking
	// InDependencyOrder, and the status a vertex is moved to after a
	// successful visit when walking InReverseDependencyOrder.
	ServiceStopped ServiceStatus = iota
	// ServiceStarted is the initial status of every vertex when walking
	// InReverseDependencyOrder, and the status a vertex is moved to after a
	// successful visit when walking InDependencyOrder.
	ServiceStarted
)
// graphTraversal describes one concurrent walk over a dependency Graph. The
// direction of the walk (from the leaves up through parents, or from the roots
// down through children) is set by its function fields; see
// upDirectionTraversal and downDirectionTraversal.
type graphTraversal struct {
	// mu guards seen, which consume reads and writes from several goroutines.
	mu sync.Mutex
	// seen holds the keys of vertices that have already been scheduled.
	seen map[string]struct{}
	// ignored holds services that are still walked (their status is updated)
	// but for which visitorFn is not called. Populated by WithRootNodesAndDown.
	ignored map[string]struct{}

	extremityNodesFn         func(*Graph) []*Vertex                        // leaves or roots
	adjacentNodesFn          func(*Vertex) []*Vertex                       // getParents or getChildren
	filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
	// targetServiceStatus is recorded on a vertex once its visit succeeds.
	targetServiceStatus ServiceStatus
	// adjacentServiceStatusToSkip: a vertex is not scheduled while any vertex
	// returned by filterAdjacentByStatusFn still has this status.
	adjacentServiceStatusToSkip ServiceStatus
	// visitorFn is called with the service name of each non-ignored vertex.
	visitorFn func(context.Context, string) error
	// maxConcurrency, when positive, bounds how many visits run at once.
	maxConcurrency int
}
  42. func upDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
  43. return &graphTraversal{
  44. extremityNodesFn: leaves,
  45. adjacentNodesFn: getParents,
  46. filterAdjacentByStatusFn: filterChildren,
  47. adjacentServiceStatusToSkip: ServiceStopped,
  48. targetServiceStatus: ServiceStarted,
  49. visitorFn: visitorFn,
  50. }
  51. }
  52. func downDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
  53. return &graphTraversal{
  54. extremityNodesFn: roots,
  55. adjacentNodesFn: getChildren,
  56. filterAdjacentByStatusFn: filterParents,
  57. adjacentServiceStatusToSkip: ServiceStarted,
  58. targetServiceStatus: ServiceStopped,
  59. visitorFn: visitorFn,
  60. }
  61. }
  62. // InDependencyOrder applies the function to the services of the project taking in account the dependency order
  63. func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
  64. graph, err := NewGraph(project.Services, ServiceStopped)
  65. if err != nil {
  66. return err
  67. }
  68. t := upDirectionTraversal(fn)
  69. for _, option := range options {
  70. option(t)
  71. }
  72. return t.visit(ctx, graph)
  73. }
  74. // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
  75. func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
  76. graph, err := NewGraph(project.Services, ServiceStarted)
  77. if err != nil {
  78. return err
  79. }
  80. t := downDirectionTraversal(fn)
  81. for _, option := range options {
  82. option(t)
  83. }
  84. return t.visit(ctx, graph)
  85. }
  86. func WithRootNodesAndDown(nodes []string) func(*graphTraversal) {
  87. return func(t *graphTraversal) {
  88. if len(nodes) == 0 {
  89. return
  90. }
  91. originalFn := t.extremityNodesFn
  92. t.extremityNodesFn = func(graph *Graph) []*Vertex {
  93. var want []string
  94. for _, node := range nodes {
  95. vertex := graph.Vertices[node]
  96. want = append(want, vertex.Service)
  97. for _, v := range getAncestors(vertex) {
  98. want = append(want, v.Service)
  99. }
  100. }
  101. t.ignored = map[string]struct{}{}
  102. for k := range graph.Vertices {
  103. if !utils.Contains(want, k) {
  104. t.ignored[k] = struct{}{}
  105. }
  106. }
  107. return originalFn(graph)
  108. }
  109. }
  110. }
  111. func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
  112. expect := len(g.Vertices)
  113. if expect == 0 {
  114. return nil
  115. }
  116. eg, ctx := errgroup.WithContext(ctx)
  117. if t.maxConcurrency > 0 {
  118. eg.SetLimit(t.maxConcurrency + 1)
  119. }
  120. nodeCh := make(chan *Vertex)
  121. eg.Go(func() error {
  122. for node := range nodeCh {
  123. expect--
  124. if expect == 0 {
  125. close(nodeCh)
  126. return nil
  127. }
  128. t.run(ctx, g, eg, t.adjacentNodesFn(node), nodeCh)
  129. }
  130. return nil
  131. })
  132. nodes := t.extremityNodesFn(g)
  133. t.run(ctx, g, eg, nodes, nodeCh)
  134. err := eg.Wait()
  135. return err
  136. }
  137. // Note: this could be `graph.walk` or whatever
  138. func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, nodeCh chan *Vertex) {
  139. for _, node := range nodes {
  140. // Don't start this service yet if all of its children have
  141. // not been started yet.
  142. if len(t.filterAdjacentByStatusFn(graph, node.Key, t.adjacentServiceStatusToSkip)) != 0 {
  143. continue
  144. }
  145. node := node
  146. if !t.consume(node.Key) {
  147. // another worker already visited this node
  148. continue
  149. }
  150. eg.Go(func() error {
  151. var err error
  152. if _, ignore := t.ignored[node.Service]; !ignore {
  153. err = t.visitorFn(ctx, node.Service)
  154. }
  155. if err == nil {
  156. graph.UpdateStatus(node.Key, t.targetServiceStatus)
  157. }
  158. nodeCh <- node
  159. return err
  160. })
  161. }
  162. }
  163. func (t *graphTraversal) consume(nodeKey string) bool {
  164. t.mu.Lock()
  165. defer t.mu.Unlock()
  166. if t.seen == nil {
  167. t.seen = make(map[string]struct{})
  168. }
  169. if _, ok := t.seen[nodeKey]; ok {
  170. return false
  171. }
  172. t.seen[nodeKey] = struct{}{}
  173. return true
  174. }
// Graph represents project as service dependencies.
type Graph struct {
	// Vertices maps each vertex key to its vertex.
	Vertices map[string]*Vertex
	// lock guards Vertices and each vertex's Status for the Graph methods
	// that take it. Vertex methods such as GetParents and GetChildren read
	// their maps without it.
	lock sync.RWMutex
}
// Vertex represents a service in the dependencies structure.
type Vertex struct {
	// Key identifies the vertex in Graph.Vertices.
	Key string
	// Service is the service name handed to the traversal's visitor function.
	Service string
	// Status is the vertex's current status; see Graph.UpdateStatus.
	Status ServiceStatus
	// Children are the vertices this one has an edge to. NewGraph adds an
	// edge from each service to each of its dependencies, so these are the
	// service's dependencies.
	Children map[string]*Vertex
	// Parents are the vertices with an edge to this one, i.e. (as built by
	// NewGraph) the services that depend on it.
	Parents map[string]*Vertex
}
  188. func getParents(v *Vertex) []*Vertex {
  189. return v.GetParents()
  190. }
  191. // GetParents returns a slice with the parent vertices of the a Vertex
  192. func (v *Vertex) GetParents() []*Vertex {
  193. var res []*Vertex
  194. for _, p := range v.Parents {
  195. res = append(res, p)
  196. }
  197. return res
  198. }
  199. func getChildren(v *Vertex) []*Vertex {
  200. return v.GetChildren()
  201. }
  202. // getAncestors return all descendents for a vertex, might contain duplicates
  203. func getAncestors(v *Vertex) []*Vertex {
  204. var descendents []*Vertex
  205. for _, parent := range v.GetParents() {
  206. descendents = append(descendents, parent)
  207. descendents = append(descendents, getAncestors(parent)...)
  208. }
  209. return descendents
  210. }
  211. // GetChildren returns a slice with the child vertices of the a Vertex
  212. func (v *Vertex) GetChildren() []*Vertex {
  213. var res []*Vertex
  214. for _, p := range v.Children {
  215. res = append(res, p)
  216. }
  217. return res
  218. }
  219. // NewGraph returns the dependency graph of the services
  220. func NewGraph(services types.Services, initialStatus ServiceStatus) (*Graph, error) {
  221. graph := &Graph{
  222. lock: sync.RWMutex{},
  223. Vertices: map[string]*Vertex{},
  224. }
  225. for _, s := range services {
  226. graph.AddVertex(s.Name, s.Name, initialStatus)
  227. }
  228. for _, s := range services {
  229. for _, name := range s.GetDependencies() {
  230. _ = graph.AddEdge(s.Name, name)
  231. }
  232. }
  233. if b, err := graph.HasCycles(); b {
  234. return nil, err
  235. }
  236. return graph, nil
  237. }
  238. // NewVertex is the constructor function for the Vertex
  239. func NewVertex(key string, service string, initialStatus ServiceStatus) *Vertex {
  240. return &Vertex{
  241. Key: key,
  242. Service: service,
  243. Status: initialStatus,
  244. Parents: map[string]*Vertex{},
  245. Children: map[string]*Vertex{},
  246. }
  247. }
  248. // AddVertex adds a vertex to the Graph
  249. func (g *Graph) AddVertex(key string, service string, initialStatus ServiceStatus) {
  250. g.lock.Lock()
  251. defer g.lock.Unlock()
  252. v := NewVertex(key, service, initialStatus)
  253. g.Vertices[key] = v
  254. }
  255. // AddEdge adds a relationship of dependency between vertices `source` and `destination`
  256. func (g *Graph) AddEdge(source string, destination string) error {
  257. g.lock.Lock()
  258. defer g.lock.Unlock()
  259. sourceVertex := g.Vertices[source]
  260. destinationVertex := g.Vertices[destination]
  261. if sourceVertex == nil {
  262. return fmt.Errorf("could not find %s", source)
  263. }
  264. if destinationVertex == nil {
  265. return fmt.Errorf("could not find %s", destination)
  266. }
  267. // If they are already connected
  268. if _, ok := sourceVertex.Children[destination]; ok {
  269. return nil
  270. }
  271. sourceVertex.Children[destination] = destinationVertex
  272. destinationVertex.Parents[source] = sourceVertex
  273. return nil
  274. }
  275. func leaves(g *Graph) []*Vertex {
  276. return g.Leaves()
  277. }
  278. // Leaves returns the slice of leaves of the graph
  279. func (g *Graph) Leaves() []*Vertex {
  280. g.lock.Lock()
  281. defer g.lock.Unlock()
  282. var res []*Vertex
  283. for _, v := range g.Vertices {
  284. if len(v.Children) == 0 {
  285. res = append(res, v)
  286. }
  287. }
  288. return res
  289. }
  290. func roots(g *Graph) []*Vertex {
  291. return g.Roots()
  292. }
  293. // Roots returns the slice of "Roots" of the graph
  294. func (g *Graph) Roots() []*Vertex {
  295. g.lock.Lock()
  296. defer g.lock.Unlock()
  297. var res []*Vertex
  298. for _, v := range g.Vertices {
  299. if len(v.Parents) == 0 {
  300. res = append(res, v)
  301. }
  302. }
  303. return res
  304. }
  305. // UpdateStatus updates the status of a certain vertex
  306. func (g *Graph) UpdateStatus(key string, status ServiceStatus) {
  307. g.lock.Lock()
  308. defer g.lock.Unlock()
  309. g.Vertices[key].Status = status
  310. }
  311. func filterChildren(g *Graph, k string, s ServiceStatus) []*Vertex {
  312. return g.FilterChildren(k, s)
  313. }
  314. // FilterChildren returns children of a certain vertex that are in a certain status
  315. func (g *Graph) FilterChildren(key string, status ServiceStatus) []*Vertex {
  316. g.lock.Lock()
  317. defer g.lock.Unlock()
  318. var res []*Vertex
  319. vertex := g.Vertices[key]
  320. for _, child := range vertex.Children {
  321. if child.Status == status {
  322. res = append(res, child)
  323. }
  324. }
  325. return res
  326. }
  327. func filterParents(g *Graph, k string, s ServiceStatus) []*Vertex {
  328. return g.FilterParents(k, s)
  329. }
  330. // FilterParents returns the parents of a certain vertex that are in a certain status
  331. func (g *Graph) FilterParents(key string, status ServiceStatus) []*Vertex {
  332. g.lock.Lock()
  333. defer g.lock.Unlock()
  334. var res []*Vertex
  335. vertex := g.Vertices[key]
  336. for _, parent := range vertex.Parents {
  337. if parent.Status == status {
  338. res = append(res, parent)
  339. }
  340. }
  341. return res
  342. }
  343. // HasCycles detects cycles in the graph
  344. func (g *Graph) HasCycles() (bool, error) {
  345. discovered := []string{}
  346. finished := []string{}
  347. for _, vertex := range g.Vertices {
  348. path := []string{
  349. vertex.Key,
  350. }
  351. if !utils.StringContains(discovered, vertex.Key) && !utils.StringContains(finished, vertex.Key) {
  352. var err error
  353. discovered, finished, err = g.visit(vertex.Key, path, discovered, finished)
  354. if err != nil {
  355. return true, err
  356. }
  357. }
  358. }
  359. return false, nil
  360. }
  361. func (g *Graph) visit(key string, path []string, discovered []string, finished []string) ([]string, []string, error) {
  362. discovered = append(discovered, key)
  363. for _, v := range g.Vertices[key].Children {
  364. path := append(path, v.Key)
  365. if utils.StringContains(discovered, v.Key) {
  366. return nil, nil, fmt.Errorf("cycle found: %s", strings.Join(path, " -> "))
  367. }
  368. if !utils.StringContains(finished, v.Key) {
  369. if _, _, err := g.visit(v.Key, path, discovered, finished); err != nil {
  370. return nil, nil, err
  371. }
  372. }
  373. }
  374. discovered = remove(discovered, key)
  375. finished = append(finished, key)
  376. return discovered, finished, nil
  377. }
  378. func remove(slice []string, item string) []string {
  379. var s []string
  380. for _, i := range slice {
  381. if i != item {
  382. s = append(s, i)
  383. }
  384. }
  385. return s
  386. }