gpu.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. /*
  2. Copyright 2020 Docker, Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package ecs
  14. import (
  15. "fmt"
  16. "math"
  17. "strconv"
  18. "github.com/compose-spec/compose-go/types"
  19. "github.com/docker/go-units"
  20. )
  // machine describes one candidate AWS instance type and the resources it
// offers; guessMachineType compares these against a project's reservations.
type machine struct {
	// id is the instance type name, e.g. "p3.8xlarge".
	id string
	// cpus is the CPU count, compared against the parsed cpus reservation.
	cpus float64
	// memory is the instance memory, compared against the memory reservation.
	memory types.UnitBytes
	// gpus is the number of GPUs the instance provides.
	gpus int64
}

// family is an ordered list of machines. filter keeps the original order and
// firstOrError returns the first remaining entry, so declaration order decides
// which machine wins when several match.
type family []machine
  28. var p3family = family{
  29. {
  30. id: "p3.2xlarge",
  31. cpus: 8,
  32. memory: 64 * units.GiB,
  33. gpus: 2,
  34. },
  35. {
  36. id: "p3.8xlarge",
  37. cpus: 32,
  38. memory: 244 * units.GiB,
  39. gpus: 4,
  40. },
  41. {
  42. id: "p3.16xlarge",
  43. cpus: 64,
  44. memory: 488 * units.GiB,
  45. gpus: 8,
  46. },
  47. }
  48. type filterFn func(machine) bool
  49. func (f family) filter(fn filterFn) family {
  50. var filtered family
  51. for _, machine := range f {
  52. if fn(machine) {
  53. filtered = append(filtered, machine)
  54. }
  55. }
  56. return filtered
  57. }
  58. func (f family) firstOrError(msg string, args ...interface{}) (machine, error) {
  59. if len(f) == 0 {
  60. return machine{}, fmt.Errorf(msg, args...)
  61. }
  62. return f[0], nil
  63. }
  64. func guessMachineType(project *types.Project) (string, error) {
  65. // we select a machine type to match all gpus-bound services requirements
  66. // once https://github.com/aws/containers-roadmap/issues/631 is implemented we can define dedicated CapacityProviders per service.
  67. requirements, err := getResourceRequirements(project)
  68. if err != nil {
  69. return "", err
  70. }
  71. instanceType, err := p3family.
  72. filter(func(m machine) bool {
  73. return m.memory >= requirements.memory
  74. }).
  75. filter(func(m machine) bool {
  76. return m.cpus >= requirements.cpus
  77. }).
  78. filter(func(m machine) bool {
  79. return m.gpus >= requirements.gpus
  80. }).
  81. firstOrError("none of the AWS p3 machines match requirement for memory:%d cpu:%f gpus:%d", requirements.memory, requirements.cpus, requirements.gpus)
  82. if err != nil {
  83. return "", err
  84. }
  85. return instanceType.id, nil
  86. }
  // resourceRequirements holds the resources reserved by a service, or the
// per-field maximum over several services once merged with combine.
type resourceRequirements struct {
	// memory is the reserved memory.
	memory types.UnitBytes
	// cpus is the reserved CPU amount, parsed as a float from the compose file.
	cpus float64
	// gpus is the number of GPUs requested through a "gpus" generic resource.
	gpus int64
}
  92. func getResourceRequirements(project *types.Project) (*resourceRequirements, error) {
  93. return toResourceRequirementsSlice(project).
  94. filter(func(requirements *resourceRequirements) bool {
  95. return requirements.gpus != 0
  96. }).
  97. max()
  98. }
  // eitherRequirementsOrError carries either a list of per-service requirements
// or the error that stopped their collection. filter and max pass a non-nil
// err through untouched, which lets the steps be chained.
type eitherRequirementsOrError struct {
	// requirements may contain nil entries for services without reservations.
	requirements []*resourceRequirements
	err          error
}
  103. func toResourceRequirementsSlice(project *types.Project) eitherRequirementsOrError {
  104. var requirements []*resourceRequirements
  105. for _, service := range project.Services {
  106. r, err := toResourceRequirements(service)
  107. if err != nil {
  108. return eitherRequirementsOrError{nil, err}
  109. }
  110. requirements = append(requirements, r)
  111. }
  112. return eitherRequirementsOrError{requirements, nil}
  113. }
  114. func (r eitherRequirementsOrError) filter(fn func(*resourceRequirements) bool) eitherRequirementsOrError {
  115. if r.err != nil {
  116. return r
  117. }
  118. var requirements []*resourceRequirements
  119. for _, req := range r.requirements {
  120. if fn(req) {
  121. requirements = append(requirements, req)
  122. }
  123. }
  124. return eitherRequirementsOrError{requirements, nil}
  125. }
  126. func toResourceRequirements(service types.ServiceConfig) (*resourceRequirements, error) {
  127. if service.Deploy == nil {
  128. return nil, nil
  129. }
  130. reservations := service.Deploy.Resources.Reservations
  131. if reservations == nil {
  132. return nil, nil
  133. }
  134. var requiredGPUs int64
  135. for _, r := range reservations.GenericResources {
  136. if r.DiscreteResourceSpec.Kind == "gpus" {
  137. requiredGPUs = r.DiscreteResourceSpec.Value
  138. break
  139. }
  140. }
  141. var nanocpu float64
  142. if reservations.NanoCPUs != "" {
  143. v, err := strconv.ParseFloat(reservations.NanoCPUs, 64)
  144. if err != nil {
  145. return nil, err
  146. }
  147. nanocpu = v
  148. }
  149. return &resourceRequirements{
  150. memory: reservations.MemoryBytes,
  151. cpus: nanocpu,
  152. gpus: requiredGPUs,
  153. }, nil
  154. }
  155. func (r resourceRequirements) combine(o *resourceRequirements) resourceRequirements {
  156. if o == nil {
  157. return r
  158. }
  159. return resourceRequirements{
  160. memory: maxUnitBytes(r.memory, o.memory),
  161. cpus: math.Max(r.cpus, o.cpus),
  162. gpus: maxInt64(r.gpus, o.gpus),
  163. }
  164. }
  165. func (r eitherRequirementsOrError) max() (*resourceRequirements, error) {
  166. if r.err != nil {
  167. return nil, r.err
  168. }
  169. min := resourceRequirements{}
  170. for _, req := range r.requirements {
  171. min = min.combine(req)
  172. }
  173. return &min, nil
  174. }
  // maxInt64 returns the larger of a and b.
func maxInt64(a, b int64) int64 {
	if b >= a {
		return b
	}
	return a
}
  181. func maxUnitBytes(a, b types.UnitBytes) types.UnitBytes {
  182. if a > b {
  183. return a
  184. }
  185. return b
  186. }