Bläddra i källkod

Merge pull request #52 from cameronmaske/fix-sort-service-dicts

Fix for #48
Ben Firshman 11 år sedan
förälder
incheckning
8a2bbeb1eb
5 ändrade filer med 139 tillägg och 39 borttagningar
  1. 2 5
      fig/cli/main.py
  2. 0 0
      fig/compat/__init__.py
  3. 0 23
      fig/compat/functools.py
  4. 34 11
      fig/project.py
  5. 103 0
      tests/sort_service_test.py

+ 2 - 5
fig/cli/main.py

@@ -8,7 +8,7 @@ import signal
 from inspect import getdoc
 
 from .. import __version__
-from ..project import NoSuchService
+from ..project import NoSuchService, DependencyError
 from ..service import CannotBeScaledError
 from .command import Command
 from .formatter import Formatter
@@ -40,10 +40,7 @@ def main():
     except KeyboardInterrupt:
         log.error("\nAborting.")
         exit(1)
-    except UserError as e:
-        log.error(e.msg)
-        exit(1)
-    except NoSuchService as e:
+    except (UserError, NoSuchService, DependencyError) as e:
         log.error(e.msg)
         exit(1)
     except NoSuchCommand as e:

+ 0 - 0
fig/compat/__init__.py


+ 0 - 23
fig/compat/functools.py

@@ -1,23 +0,0 @@
-
-# Taken from python2.7/3.3 functools
-def cmp_to_key(mycmp):
-    """Convert a cmp= function into a key= function"""
-    class K(object):
-        __slots__ = ['obj']
-        def __init__(self, obj):
-            self.obj = obj
-        def __lt__(self, other):
-            return mycmp(self.obj, other.obj) < 0
-        def __gt__(self, other):
-            return mycmp(self.obj, other.obj) > 0
-        def __eq__(self, other):
-            return mycmp(self.obj, other.obj) == 0
-        def __le__(self, other):
-            return mycmp(self.obj, other.obj) <= 0
-        def __ge__(self, other):
-            return mycmp(self.obj, other.obj) >= 0
-        def __ne__(self, other):
-            return mycmp(self.obj, other.obj) != 0
-        __hash__ = None
-    return K
-

+ 34 - 11
fig/project.py

@@ -2,21 +2,36 @@ from __future__ import unicode_literals
 from __future__ import absolute_import
 import logging
 from .service import Service
-from .compat.functools import cmp_to_key
 
 log = logging.getLogger(__name__)
 
+
 def sort_service_dicts(services):
-    # Sort in dependency order
-    def cmp(x, y):
-        x_deps_y = y['name'] in x.get('links', [])
-        y_deps_x = x['name'] in y.get('links', [])
-        if x_deps_y and not y_deps_x:
-            return 1
-        elif y_deps_x and not x_deps_y:
-            return -1
-        return 0
-    return sorted(services, key=cmp_to_key(cmp))
+    # Get all services that are dependant on another.
+    dependent_services = [s for s in services if s.get('links')]
+    flatten_links = sum([s['links'] for s in dependent_services], [])
+    # Get all services that are not linked to and don't link to others.
+    non_dependent_sevices = [s for s in services if s['name'] not in flatten_links and not s.get('links')]
+    sorted_services = []
+    # Topological sort.
+    while dependent_services:
+        n = dependent_services.pop()
+        # Check if a service is dependent on itself, if so raise an error.
+        if n['name'] in n.get('links', []):
+            raise DependencyError('A service can not link to itself: %s' % n['name'])
+        sorted_services.append(n)
+        for l in n['links']:
+            # Get the linked service.
+            linked_service = next(s for s in services if l == s['name'])
+            # Check that there isn't a circular import between services.
+            if n['name'] in linked_service.get('links', []):
+                raise DependencyError('Circular import between %s and %s' % (n['name'], linked_service['name']))
+            # Check the linked service has no links and is not already in the
+            # sorted service list.
+            if not linked_service.get('links') and linked_service not in sorted_services:
+                sorted_services.insert(0, linked_service)
+    return non_dependent_sevices + sorted_services
+
 
 class Project(object):
     """
@@ -134,3 +149,11 @@ class NoSuchService(Exception):
 
     def __str__(self):
         return self.msg
+
+
+class DependencyError(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return self.msg

+ 103 - 0
tests/sort_service_test.py

@@ -0,0 +1,103 @@
+from fig.project import sort_service_dicts, DependencyError
+from . import unittest
+
+
+class SortServiceTest(unittest.TestCase):
+    def test_sort_service_dicts_1(self):
+        services = [
+            {
+                'links': ['redis'],
+                'name': 'web'
+            },
+            {
+                'name': 'grunt'
+            },
+            {
+                'name': 'redis'
+            }
+        ]
+
+        sorted_services = sort_service_dicts(services)
+        self.assertEqual(len(sorted_services), 3)
+        self.assertEqual(sorted_services[0]['name'], 'grunt')
+        self.assertEqual(sorted_services[1]['name'], 'redis')
+        self.assertEqual(sorted_services[2]['name'], 'web')
+
+    def test_sort_service_dicts_2(self):
+        services = [
+            {
+                'links': ['redis', 'postgres'],
+                'name': 'web'
+            },
+            {
+                'name': 'postgres',
+                'links': ['redis']
+            },
+            {
+                'name': 'redis'
+            }
+        ]
+
+        sorted_services = sort_service_dicts(services)
+        self.assertEqual(len(sorted_services), 3)
+        self.assertEqual(sorted_services[0]['name'], 'redis')
+        self.assertEqual(sorted_services[1]['name'], 'postgres')
+        self.assertEqual(sorted_services[2]['name'], 'web')
+
+    def test_sort_service_dicts_circular_imports(self):
+        services = [
+            {
+                'links': ['redis'],
+                'name': 'web'
+            },
+            {
+                'name': 'redis',
+                'links': ['web']
+            },
+        ]
+
+        try:
+            sort_service_dicts(services)
+        except DependencyError as e:
+            self.assertIn('redis', e.msg)
+            self.assertIn('web', e.msg)
+        else:
+            self.fail('Should have thrown an DependencyError')
+
+    def test_sort_service_dicts_circular_imports_2(self):
+        services = [
+            {
+                'links': ['postgres', 'redis'],
+                'name': 'web'
+            },
+            {
+                'name': 'redis',
+                'links': ['web']
+            },
+            {
+                'name': 'postgres'
+            }
+        ]
+
+        try:
+            sort_service_dicts(services)
+        except DependencyError as e:
+            self.assertIn('redis', e.msg)
+            self.assertIn('web', e.msg)
+        else:
+            self.fail('Should have thrown an DependencyError')
+
+    def test_sort_service_dicts_self_imports(self):
+        services = [
+            {
+                'links': ['web'],
+                'name': 'web'
+            },
+        ]
+
+        try:
+            sort_service_dicts(services)
+        except DependencyError as e:
+            self.assertIn('web', e.msg)
+        else:
+            self.fail('Should have thrown an DependencyError')