From 6b0b54b35bb4039a59d3ac53c58b998800072726 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 23 Nov 2014 10:42:39 -0500 Subject: [PATCH] - ensure we include for dependencies when we do a stamp, add an option to filter_for_lineage --- alembic/revision.py | 13 ++++++++----- alembic/script.py | 2 +- tests/test_version_traversal.py | 7 +++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/alembic/revision.py b/alembic/revision.py index 237cf79..1afb203 100644 --- a/alembic/revision.py +++ b/alembic/revision.py @@ -341,7 +341,8 @@ class RevisionMap(object): (revision.revision, check_branch)) return revision - def filter_for_lineage(self, targets, check_against): + def filter_for_lineage( + self, targets, check_against, include_dependencies=False): id_, branch_label = self._resolve_revision_number(check_against) shares = [] @@ -354,9 +355,11 @@ class RevisionMap(object): return [ tg for tg in targets - if self._shares_lineage(tg, shares)] + if self._shares_lineage( + tg, shares, include_dependencies=include_dependencies)] - def _shares_lineage(self, target, test_against_revs): + def _shares_lineage( + self, target, test_against_revs, include_dependencies=False): if not test_against_revs: return True if not isinstance(target, Revision): @@ -372,9 +375,9 @@ class RevisionMap(object): return bool( set(self._get_descendant_nodes([target], - include_dependencies=False)) + include_dependencies=include_dependencies)) .union(self._get_ancestor_nodes([target], - include_dependencies=False)) + include_dependencies=include_dependencies)) .intersection(test_against_revs) ) diff --git a/alembic/script.py b/alembic/script.py index 79fe01e..763d048 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -329,7 +329,7 @@ class ScriptDirectory(object): # filter for lineage will resolve things like # branchname@base, version@base, etc. filtered_heads = self.revision_map.filter_for_lineage( - heads, revision) + heads, revision, include_dependencies=True) dest = self.get_revision(revision) diff --git a/tests/test_version_traversal.py b/tests/test_version_traversal.py index a503a52..1fc99e2 100644 --- a/tests/test_version_traversal.py +++ b/tests/test_version_traversal.py @@ -457,6 +457,13 @@ class DependsOnBranchTestOne(MigrationTest): head.update_to_step(self.down_(self.d1)) eq_(head.heads, set([self.c2.revision])) + def test_stamp_across_dependency(self): + heads = [self.e1.revision, self.c2.revision] + head = HeadMaintainer(mock.Mock(), heads) + for step in self.env._stamp_revs(self.b1.revision, heads): + head.update_to_step(step) + eq_(head.heads, set([self.b1.revision])) + class DependsOnBranchTestTwo(MigrationTest): @classmethod