فهرست منبع

Implement device requests for GPU support

Signed-off-by: aiordache <[email protected]>
aiordache 5 سال پیش
والد
کامیت
854c003359
3فایلهای تغییر یافته به همراه50 افزوده شده و 2 حذف شده
  1. 19 1
      compose/config/compose_spec.json
  2. 27 1
      compose/project.py
  3. 4 0
      compose/service.py

+ 19 - 1
compose/config/compose_spec.json

@@ -524,7 +524,8 @@
               "properties": {
                 "cpus": {"type": ["number", "string"]},
                 "memory": {"type": "string"},
-                "generic_resources": {"$ref": "#/definitions/generic_resources"}
+                "generic_resources": {"$ref": "#/definitions/generic_resources"},
+                "devices": {"$ref": "#/definitions/devices"}
               },
               "additionalProperties": false,
               "patternProperties": {"^x-": {}}
@@ -590,6 +591,23 @@
       }
     },
 
+    "devices": {
+      "id": "#/definitions/devices",
+      "type": "array",
+      "items": {
+        "type": "object",
+        "properties": {
+            "capabilities": {"$ref": "#/definitions/list_of_strings"},
+            "count": {"type": ["string", "integer"]},
+            "device_ids": {"$ref": "#/definitions/list_of_strings"},
+            "driver":{"type": "string"},
+            "options":{"$ref": "#/definitions/list_or_dict"}
+          },
+        "additionalProperties": false,
+        "patternProperties": {"^x-": {}}
+      }
+    },
+
     "network": {
       "id": "#/definitions/network",
       "type": ["object", "null"],

+ 27 - 1
compose/project.py

@@ -128,7 +128,7 @@ class Project:
                 config_data.secrets)
 
             service_dict['scale'] = project.get_service_scale(service_dict)
-
+            device_requests = project.get_device_requests(service_dict)
             service_dict = translate_credential_spec_to_security_opt(service_dict)
             service_dict, ignored_keys = translate_deploy_keys_to_container_config(
                 service_dict
@@ -154,6 +154,7 @@ class Project:
                     ipc_mode=ipc_mode,
                     platform=service_dict.pop('platform', None),
                     default_platform=default_platform,
+                    device_requests=device_requests,
                     extra_labels=extra_labels,
                     **service_dict)
             )
@@ -331,6 +332,31 @@ class Project:
                 max_replicas))
         return scale
 
+    def get_device_requests(self, service_dict):
+        deploy_dict = service_dict.get('deploy', None)
+        if not deploy_dict:
+            return
+
+        resources = deploy_dict.get('resources', None)
+        if not resources or not resources.get('reservations', None):
+            return
+        devices = resources['reservations'].get('devices')
+        if not devices:
+            return
+
+        for dev in devices:
+            count = dev.get("count", -1)
+            if not isinstance(count, int):
+                if count != "all":
+                    raise ConfigurationError(
+                        'Invalid value "{}" for devices count'.format(dev["count"]),
+                        '(expected integer or "all")')
+                dev["count"] = -1
+
+            if 'capabilities' in dev:
+                dev['capabilities'] = [dev['capabilities']]
+        return devices
+
     def start(self, service_names=None, **options):
         containers = []
 

+ 4 - 0
compose/service.py

@@ -77,6 +77,7 @@ HOST_CONFIG_KEYS = [
     'cpuset',
     'device_cgroup_rules',
     'devices',
+    'device_requests',
     'dns',
     'dns_search',
     'dns_opt',
@@ -180,6 +181,7 @@ class Service:
             pid_mode=None,
             default_platform=None,
             extra_labels=None,
+            device_requests=None,
             **options
     ):
         self.name = name
@@ -195,6 +197,7 @@ class Service:
         self.secrets = secrets or []
         self.scale_num = scale
         self.default_platform = default_platform
+        self.device_requests = device_requests
         self.options = options
         self.extra_labels = extra_labels or []
 
@@ -1016,6 +1019,7 @@ class Service:
             privileged=options.get('privileged', False),
             network_mode=self.network_mode.mode,
             devices=options.get('devices'),
+            device_requests=self.device_requests,
             dns=options.get('dns'),
             dns_opt=options.get('dns_opt'),
             dns_search=options.get('dns_search'),