dependencies.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. /*
  2. Copyright 2020 Docker Compose CLI authors
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package compose
  14. import (
  15. "context"
  16. "fmt"
  17. "strings"
  18. "sync"
  19. "github.com/compose-spec/compose-go/types"
  20. "github.com/docker/compose/v2/pkg/api"
  21. "github.com/pkg/errors"
  22. "golang.org/x/sync/errgroup"
  23. "github.com/docker/compose/v2/pkg/utils"
  24. )
  25. // ServiceStatus indicates the status of a service
  26. type ServiceStatus int
  27. // Services status flags
  28. const (
  29. ServiceStopped ServiceStatus = iota
  30. ServiceStarted
  31. )
  32. type graphTraversal struct {
  33. mu sync.Mutex
  34. seen map[string]struct{}
  35. ignored map[string]struct{}
  36. extremityNodesFn func(*Graph) []*Vertex // leaves or roots
  37. adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren
  38. filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
  39. targetServiceStatus ServiceStatus
  40. adjacentServiceStatusToSkip ServiceStatus
  41. visitorFn func(context.Context, string) error
  42. maxConcurrency int
  43. }
  44. func upDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
  45. return &graphTraversal{
  46. extremityNodesFn: leaves,
  47. adjacentNodesFn: getParents,
  48. filterAdjacentByStatusFn: filterChildren,
  49. adjacentServiceStatusToSkip: ServiceStopped,
  50. targetServiceStatus: ServiceStarted,
  51. visitorFn: visitorFn,
  52. }
  53. }
  54. func downDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
  55. return &graphTraversal{
  56. extremityNodesFn: roots,
  57. adjacentNodesFn: getChildren,
  58. filterAdjacentByStatusFn: filterParents,
  59. adjacentServiceStatusToSkip: ServiceStarted,
  60. targetServiceStatus: ServiceStopped,
  61. visitorFn: visitorFn,
  62. }
  63. }
  64. // InDependencyOrder applies the function to the services of the project taking in account the dependency order
  65. func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
  66. graph, err := NewGraph(project, ServiceStopped)
  67. if err != nil {
  68. return err
  69. }
  70. t := upDirectionTraversal(fn)
  71. for _, option := range options {
  72. option(t)
  73. }
  74. return t.visit(ctx, graph)
  75. }
  76. // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
  77. func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
  78. graph, err := NewGraph(project, ServiceStarted)
  79. if err != nil {
  80. return err
  81. }
  82. t := downDirectionTraversal(fn)
  83. for _, option := range options {
  84. option(t)
  85. }
  86. return t.visit(ctx, graph)
  87. }
  88. func WithRootNodesAndDown(nodes []string) func(*graphTraversal) {
  89. return func(t *graphTraversal) {
  90. if len(nodes) == 0 {
  91. return
  92. }
  93. originalFn := t.extremityNodesFn
  94. t.extremityNodesFn = func(graph *Graph) []*Vertex {
  95. var want []string
  96. for _, node := range nodes {
  97. vertex := graph.Vertices[node]
  98. want = append(want, vertex.Service)
  99. for _, v := range getAncestors(vertex) {
  100. want = append(want, v.Service)
  101. }
  102. }
  103. t.ignored = map[string]struct{}{}
  104. for k := range graph.Vertices {
  105. if !utils.Contains(want, k) {
  106. t.ignored[k] = struct{}{}
  107. }
  108. }
  109. return originalFn(graph)
  110. }
  111. }
  112. }
  113. func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
  114. expect := len(g.Vertices)
  115. if expect == 0 {
  116. return nil
  117. }
  118. eg, ctx := errgroup.WithContext(ctx)
  119. if t.maxConcurrency > 0 {
  120. eg.SetLimit(t.maxConcurrency + 1)
  121. }
  122. nodeCh := make(chan *Vertex)
  123. eg.Go(func() error {
  124. for node := range nodeCh {
  125. expect--
  126. if expect == 0 {
  127. close(nodeCh)
  128. return nil
  129. }
  130. t.run(ctx, g, eg, t.adjacentNodesFn(node), nodeCh)
  131. }
  132. return nil
  133. })
  134. nodes := t.extremityNodesFn(g)
  135. t.run(ctx, g, eg, nodes, nodeCh)
  136. err := eg.Wait()
  137. return err
  138. }
  139. // Note: this could be `graph.walk` or whatever
  140. func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, nodeCh chan *Vertex) {
  141. for _, node := range nodes {
  142. // Don't start this service yet if all of its children have
  143. // not been started yet.
  144. if len(t.filterAdjacentByStatusFn(graph, node.Key, t.adjacentServiceStatusToSkip)) != 0 {
  145. continue
  146. }
  147. node := node
  148. if !t.consume(node.Key) {
  149. // another worker already visited this node
  150. continue
  151. }
  152. eg.Go(func() error {
  153. var err error
  154. if _, ignore := t.ignored[node.Service]; !ignore {
  155. err = t.visitorFn(ctx, node.Service)
  156. }
  157. if err == nil {
  158. graph.UpdateStatus(node.Key, t.targetServiceStatus)
  159. }
  160. nodeCh <- node
  161. return err
  162. })
  163. }
  164. }
  165. func (t *graphTraversal) consume(nodeKey string) bool {
  166. t.mu.Lock()
  167. defer t.mu.Unlock()
  168. if t.seen == nil {
  169. t.seen = make(map[string]struct{})
  170. }
  171. if _, ok := t.seen[nodeKey]; ok {
  172. return false
  173. }
  174. t.seen[nodeKey] = struct{}{}
  175. return true
  176. }
  177. // Graph represents project as service dependencies
  178. type Graph struct {
  179. Vertices map[string]*Vertex
  180. lock sync.RWMutex
  181. }
  182. // Vertex represents a service in the dependencies structure
  183. type Vertex struct {
  184. Key string
  185. Service string
  186. Status ServiceStatus
  187. Children map[string]*Vertex
  188. Parents map[string]*Vertex
  189. }
  190. func getParents(v *Vertex) []*Vertex {
  191. return v.GetParents()
  192. }
  193. // GetParents returns a slice with the parent vertices of the a Vertex
  194. func (v *Vertex) GetParents() []*Vertex {
  195. var res []*Vertex
  196. for _, p := range v.Parents {
  197. res = append(res, p)
  198. }
  199. return res
  200. }
  201. func getChildren(v *Vertex) []*Vertex {
  202. return v.GetChildren()
  203. }
  204. // getAncestors return all descendents for a vertex, might contain duplicates
  205. func getAncestors(v *Vertex) []*Vertex {
  206. var descendents []*Vertex
  207. for _, parent := range v.GetParents() {
  208. descendents = append(descendents, parent)
  209. descendents = append(descendents, getAncestors(parent)...)
  210. }
  211. return descendents
  212. }
  213. // GetChildren returns a slice with the child vertices of the a Vertex
  214. func (v *Vertex) GetChildren() []*Vertex {
  215. var res []*Vertex
  216. for _, p := range v.Children {
  217. res = append(res, p)
  218. }
  219. return res
  220. }
  221. // NewGraph returns the dependency graph of the services
  222. func NewGraph(project *types.Project, initialStatus ServiceStatus) (*Graph, error) {
  223. graph := &Graph{
  224. lock: sync.RWMutex{},
  225. Vertices: map[string]*Vertex{},
  226. }
  227. for _, s := range project.Services {
  228. graph.AddVertex(s.Name, s.Name, initialStatus)
  229. }
  230. for _, s := range project.Services {
  231. for _, name := range s.GetDependencies() {
  232. err := graph.AddEdge(s.Name, name)
  233. if err != nil {
  234. if api.IsNotFoundError(err) {
  235. ds, err := project.GetDisabledService(name)
  236. if err == nil {
  237. return nil, fmt.Errorf("service %s is required by %s but is disabled. Can be enabled by profiles %s", name, s.Name, ds.Profiles)
  238. }
  239. }
  240. return nil, err
  241. }
  242. }
  243. }
  244. if b, err := graph.HasCycles(); b {
  245. return nil, err
  246. }
  247. return graph, nil
  248. }
  249. // NewVertex is the constructor function for the Vertex
  250. func NewVertex(key string, service string, initialStatus ServiceStatus) *Vertex {
  251. return &Vertex{
  252. Key: key,
  253. Service: service,
  254. Status: initialStatus,
  255. Parents: map[string]*Vertex{},
  256. Children: map[string]*Vertex{},
  257. }
  258. }
  259. // AddVertex adds a vertex to the Graph
  260. func (g *Graph) AddVertex(key string, service string, initialStatus ServiceStatus) {
  261. g.lock.Lock()
  262. defer g.lock.Unlock()
  263. v := NewVertex(key, service, initialStatus)
  264. g.Vertices[key] = v
  265. }
  266. // AddEdge adds a relationship of dependency between vertices `source` and `destination`
  267. func (g *Graph) AddEdge(source string, destination string) error {
  268. g.lock.Lock()
  269. defer g.lock.Unlock()
  270. sourceVertex := g.Vertices[source]
  271. destinationVertex := g.Vertices[destination]
  272. if sourceVertex == nil {
  273. return errors.Wrapf(api.ErrNotFound, "could not find %s", source)
  274. }
  275. if destinationVertex == nil {
  276. return errors.Wrapf(api.ErrNotFound, "could not find %s", destination)
  277. }
  278. // If they are already connected
  279. if _, ok := sourceVertex.Children[destination]; ok {
  280. return nil
  281. }
  282. sourceVertex.Children[destination] = destinationVertex
  283. destinationVertex.Parents[source] = sourceVertex
  284. return nil
  285. }
  286. func leaves(g *Graph) []*Vertex {
  287. return g.Leaves()
  288. }
  289. // Leaves returns the slice of leaves of the graph
  290. func (g *Graph) Leaves() []*Vertex {
  291. g.lock.Lock()
  292. defer g.lock.Unlock()
  293. var res []*Vertex
  294. for _, v := range g.Vertices {
  295. if len(v.Children) == 0 {
  296. res = append(res, v)
  297. }
  298. }
  299. return res
  300. }
  301. func roots(g *Graph) []*Vertex {
  302. return g.Roots()
  303. }
  304. // Roots returns the slice of "Roots" of the graph
  305. func (g *Graph) Roots() []*Vertex {
  306. g.lock.Lock()
  307. defer g.lock.Unlock()
  308. var res []*Vertex
  309. for _, v := range g.Vertices {
  310. if len(v.Parents) == 0 {
  311. res = append(res, v)
  312. }
  313. }
  314. return res
  315. }
  316. // UpdateStatus updates the status of a certain vertex
  317. func (g *Graph) UpdateStatus(key string, status ServiceStatus) {
  318. g.lock.Lock()
  319. defer g.lock.Unlock()
  320. g.Vertices[key].Status = status
  321. }
  322. func filterChildren(g *Graph, k string, s ServiceStatus) []*Vertex {
  323. return g.FilterChildren(k, s)
  324. }
  325. // FilterChildren returns children of a certain vertex that are in a certain status
  326. func (g *Graph) FilterChildren(key string, status ServiceStatus) []*Vertex {
  327. g.lock.Lock()
  328. defer g.lock.Unlock()
  329. var res []*Vertex
  330. vertex := g.Vertices[key]
  331. for _, child := range vertex.Children {
  332. if child.Status == status {
  333. res = append(res, child)
  334. }
  335. }
  336. return res
  337. }
  338. func filterParents(g *Graph, k string, s ServiceStatus) []*Vertex {
  339. return g.FilterParents(k, s)
  340. }
  341. // FilterParents returns the parents of a certain vertex that are in a certain status
  342. func (g *Graph) FilterParents(key string, status ServiceStatus) []*Vertex {
  343. g.lock.Lock()
  344. defer g.lock.Unlock()
  345. var res []*Vertex
  346. vertex := g.Vertices[key]
  347. for _, parent := range vertex.Parents {
  348. if parent.Status == status {
  349. res = append(res, parent)
  350. }
  351. }
  352. return res
  353. }
  354. // HasCycles detects cycles in the graph
  355. func (g *Graph) HasCycles() (bool, error) {
  356. discovered := []string{}
  357. finished := []string{}
  358. for _, vertex := range g.Vertices {
  359. path := []string{
  360. vertex.Key,
  361. }
  362. if !utils.StringContains(discovered, vertex.Key) && !utils.StringContains(finished, vertex.Key) {
  363. var err error
  364. discovered, finished, err = g.visit(vertex.Key, path, discovered, finished)
  365. if err != nil {
  366. return true, err
  367. }
  368. }
  369. }
  370. return false, nil
  371. }
  372. func (g *Graph) visit(key string, path []string, discovered []string, finished []string) ([]string, []string, error) {
  373. discovered = append(discovered, key)
  374. for _, v := range g.Vertices[key].Children {
  375. path := append(path, v.Key)
  376. if utils.StringContains(discovered, v.Key) {
  377. return nil, nil, fmt.Errorf("cycle found: %s", strings.Join(path, " -> "))
  378. }
  379. if !utils.StringContains(finished, v.Key) {
  380. if _, _, err := g.visit(v.Key, path, discovered, finished); err != nil {
  381. return nil, nil, err
  382. }
  383. }
  384. }
  385. discovered = remove(discovered, key)
  386. finished = append(finished, key)
  387. return discovered, finished, nil
  388. }
  389. func remove(slice []string, item string) []string {
  390. var s []string
  391. for _, i := range slice {
  392. if i != item {
  393. s = append(s, i)
  394. }
  395. }
  396. return s
  397. }