Преглед изворни кода

Merge pull request #456 from dnephin/volumes_from_service

Fix volumes_from a service with no containers
Aanand Prasad пре 11 година
родитељ
комит
267be12bb2
2 измењених фајлова са 51 додато и 6 уклоњено
  1. 12 6
      fig/service.py
  2. 39 0
      tests/unit/service_test.py

+ 12 - 6
fig/service.py

@@ -4,6 +4,7 @@ from collections import namedtuple
 import logging
 import re
 import os
+from operator import attrgetter
 import sys
 
 from docker.errors import APIError
@@ -308,12 +309,17 @@ class Service(object):
 
     def _get_volumes_from(self, intermediate_container=None):
         volumes_from = []
-        for v in self.volumes_from:
-            if isinstance(v, Service):
-                for container in v.containers(stopped=True):
-                    volumes_from.append(container.id)
-            elif isinstance(v, Container):
-                volumes_from.append(v.id)
+        for volume_source in self.volumes_from:
+            if isinstance(volume_source, Service):
+                containers = volume_source.containers(stopped=True)
+
+                if not containers:
+                    volumes_from.append(volume_source.create_container().id)
+                else:
+                    volumes_from.extend(map(attrgetter('id'), containers))
+
+            elif isinstance(volume_source, Container):
+                volumes_from.append(volume_source.id)
 
         if intermediate_container:
             volumes_from.append(intermediate_container.id)

+ 39 - 0
tests/unit/service_test.py

@@ -8,6 +8,7 @@ import mock
 import docker
 
 from fig import Service
+from fig.container import Container
 from fig.service import (
     ConfigError,
     split_port,
@@ -44,6 +45,44 @@ class ServiceTest(unittest.TestCase):
         self.assertRaises(ConfigError, lambda: Service(name='foo', port=['8000']))
         Service(name='foo', ports=['8000'])
 
+    def test_get_volumes_from_container(self):
+        container_id = 'aabbccddee'
+        service = Service(
+            'test',
+            volumes_from=[mock.Mock(id=container_id, spec=Container)])
+
+        self.assertEqual(service._get_volumes_from(), [container_id])
+
+    def test_get_volumes_from_intermediate_container(self):
+        container_id = 'aabbccddee'
+        service = Service('test')
+        container = mock.Mock(id=container_id, spec=Container)
+
+        self.assertEqual(service._get_volumes_from(container), [container_id])
+
+    def test_get_volumes_from_service_container_exists(self):
+        container_ids = ['aabbccddee', '12345']
+        from_service = mock.create_autospec(Service)
+        from_service.containers.return_value = [
+            mock.Mock(id=container_id, spec=Container)
+            for container_id in container_ids
+        ]
+        service = Service('test', volumes_from=[from_service])
+
+        self.assertEqual(service._get_volumes_from(), container_ids)
+
+    def test_get_volumes_from_service_no_container(self):
+        container_id = 'abababab'
+        from_service = mock.create_autospec(Service)
+        from_service.containers.return_value = []
+        from_service.create_container.return_value = mock.Mock(
+            id=container_id,
+            spec=Container)
+        service = Service('test', volumes_from=[from_service])
+
+        self.assertEqual(service._get_volumes_from(), [container_id])
+        from_service.create_container.assert_called_once_with()
+
     def test_split_port_with_host_ip(self):
         internal_port, external_port = split_port("127.0.0.1:1000:2000")
         self.assertEqual(internal_port, "2000")