Added preliminary support for composite key batch fetching
This commit is contained in:
		| @@ -201,7 +201,10 @@ class CompositeFetcher(object): | ||||
|     def fetch(self): | ||||
|         for entity in self.related_entities: | ||||
|             for fetcher in self.fetchers: | ||||
|                 if getattr(entity, fetcher.remote_column_name) is not None: | ||||
|                 if any( | ||||
|                     getattr(entity, name) | ||||
|                     for name in fetcher.remote_column_names | ||||
|                 ): | ||||
|                     fetcher.append_entity(entity) | ||||
|  | ||||
|     def populate(self): | ||||
| @@ -227,29 +230,46 @@ class Fetcher(object): | ||||
|         return self.path.session.query(self.path.model).filter(self.condition) | ||||
|  | ||||
|     @property | ||||
|     def remote_column_name(self): | ||||
|         return list(self.path.property.remote_side)[0].name | ||||
|     def local_column_names(self): | ||||
|         names = [] | ||||
|         for local, remote in self.prop.local_remote_pairs: | ||||
|             for fk in remote.foreign_keys: | ||||
|                 # TODO: make this support inherited tables | ||||
|                 if fk.column.table in self.prop.parent.tables: | ||||
|                     names.append(local.name) | ||||
|         return names | ||||
|  | ||||
|     def parent_key(self, entity): | ||||
|         return tuple( | ||||
|             getattr(entity, name) | ||||
|             for name in self.remote_column_names | ||||
|         ) | ||||
|  | ||||
|     def local_values(self, entity): | ||||
|         return getattr(entity, list(self.prop.local_columns)[0].name) | ||||
|         return tuple( | ||||
|             getattr(entity, name) | ||||
|             for name in self.local_column_names | ||||
|         ) | ||||
|  | ||||
|     def populate_backrefs(self, related_entities): | ||||
|         """ | ||||
|         Populates backrefs for given related entities. | ||||
|         """ | ||||
|         backref_dict = dict( | ||||
|             (self.local_values(entity), []) | ||||
|             for entity, parent_id in related_entities | ||||
|             (self.local_values(value[0]), []) | ||||
|             for value in related_entities | ||||
|         ) | ||||
|         for entity, parent_id in related_entities: | ||||
|             backref_dict[self.local_values(entity)].append( | ||||
|                 self.path.session.query(self.path.parent_model).get(parent_id) | ||||
|         for value in related_entities: | ||||
|             backref_dict[self.local_values(value[0])].append( | ||||
|                 self.path.session.query(self.path.parent_model).get( | ||||
|                     tuple(value[1:]) | ||||
|                 ) | ||||
|             ) | ||||
|         for entity, parent_id in related_entities: | ||||
|         for value in related_entities: | ||||
|             set_committed_value( | ||||
|                 entity, | ||||
|                 value[0], | ||||
|                 self.prop.back_populates, | ||||
|                 backref_dict[self.local_values(entity)] | ||||
|                 backref_dict[self.local_values(value[0])] | ||||
|             ) | ||||
|  | ||||
|     def populate(self): | ||||
| @@ -257,6 +277,12 @@ class Fetcher(object): | ||||
|         Populate batch fetched entities to parent objects. | ||||
|         """ | ||||
|         for entity in self.path.entities: | ||||
|             # print ( | ||||
|             #     "setting committed value for ", | ||||
|             #     entity, | ||||
|             #     " using local values ", | ||||
|             #     self.local_values(entity) | ||||
|             # ) | ||||
|             set_committed_value( | ||||
|                 entity, | ||||
|                 self.prop.key, | ||||
| @@ -268,9 +294,25 @@ class Fetcher(object): | ||||
|  | ||||
|     @property | ||||
|     def condition(self): | ||||
|         return getattr(self.path.model, self.remote_column_name).in_( | ||||
|             self.local_values_list | ||||
|         ) | ||||
|         names = self.remote_column_names | ||||
|         if len(names) == 1: | ||||
|             return getattr(self.path.model, names[0]).in_( | ||||
|                 value[0] for value in self.local_values_list | ||||
|             ) | ||||
|         else: | ||||
|             conditions = [] | ||||
|             for entity in self.path.entities: | ||||
|                 conditions.append( | ||||
|                     sa.and_( | ||||
|                         *[ | ||||
|                             getattr(self.path.model, remote.name) | ||||
|                             == | ||||
|                             getattr(entity, local.name) | ||||
|                             for local, remote in self.prop.local_remote_pairs | ||||
|                         ] | ||||
|                     ) | ||||
|                 ) | ||||
|             return sa.or_(*conditions) | ||||
|  | ||||
|     def fetch(self): | ||||
|         for entity in self.related_entities: | ||||
| @@ -279,12 +321,40 @@ class Fetcher(object): | ||||
|  | ||||
| class ManyToManyFetcher(Fetcher): | ||||
|     @property | ||||
|     def remote_column_name(self): | ||||
|         for column in self.prop.remote_side: | ||||
|             for fk in column.foreign_keys: | ||||
|     def remote_column_names(self): | ||||
|         names = [] | ||||
|         for local, remote in self.prop.local_remote_pairs: | ||||
|             for fk in remote.foreign_keys: | ||||
|                 # TODO: make this support inherited tables | ||||
|                 if fk.column.table == self.path.parent_model.__table__: | ||||
|                     return fk.parent.name | ||||
|                     names.append(fk.parent.name) | ||||
|  | ||||
|         return names | ||||
|  | ||||
|     @property | ||||
|     def condition(self): | ||||
|         if len(self.remote_column_names) == 1: | ||||
|             return ( | ||||
|                 getattr(self.prop.secondary.c, self.remote_column_names[0]) | ||||
|                 .in_( | ||||
|                     [value[0] for value in self.local_values_list] | ||||
|                 ) | ||||
|             ) | ||||
|         else: | ||||
|             conditions = [] | ||||
|             for entity in self.path.entities: | ||||
|                 conditions.append( | ||||
|                     sa.and_( | ||||
|                         *[ | ||||
|                             getattr(self.prop.secondary.c, remote.name) | ||||
|                             == | ||||
|                             getattr(entity, local.name) | ||||
|                             for local, remote in self.prop.local_remote_pairs | ||||
|                             if remote.name in self.remote_column_names | ||||
|                         ] | ||||
|                     ) | ||||
|                 ) | ||||
|             return sa.or_(*conditions) | ||||
|  | ||||
|     @property | ||||
|     def related_entities(self): | ||||
| @@ -292,22 +362,23 @@ class ManyToManyFetcher(Fetcher): | ||||
|             self.path.session | ||||
|             .query( | ||||
|                 self.path.model, | ||||
|                 getattr(self.prop.secondary.c, self.remote_column_name) | ||||
|                 *[ | ||||
|                     getattr(self.prop.secondary.c, name) | ||||
|                     for name in self.remote_column_names | ||||
|                 ] | ||||
|             ) | ||||
|             .join( | ||||
|                 self.prop.secondary, self.prop.secondaryjoin | ||||
|             ) | ||||
|             .filter( | ||||
|                 getattr(self.prop.secondary.c, self.remote_column_name).in_( | ||||
|                     self.local_values_list | ||||
|                 ) | ||||
|                 self.condition | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def fetch(self): | ||||
|         for entity, parent_id in self.related_entities: | ||||
|             self.parent_dict[parent_id].append( | ||||
|                 entity | ||||
|         for value in self.related_entities: | ||||
|             self.parent_dict[tuple(value[1:])].append( | ||||
|                 value[0] | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @@ -317,11 +388,38 @@ class ManyToOneFetcher(Fetcher): | ||||
|         self.parent_dict = defaultdict(lambda: None) | ||||
|  | ||||
|     def append_entity(self, entity): | ||||
|         self.parent_dict[getattr(entity, self.remote_column_name)] = entity | ||||
|         #print 'appending entity ', entity, ' to key ', self.parent_key(entity) | ||||
|         self.parent_dict[self.parent_key(entity)] = entity | ||||
|  | ||||
|     @property | ||||
|     def remote_column_names(self): | ||||
|         names = [] | ||||
|         for local, remote in self.prop.local_remote_pairs: | ||||
|             names.append(remote.name) | ||||
|         return names | ||||
|  | ||||
|     @property | ||||
|     def local_column_names(self): | ||||
|         names = [] | ||||
|         for local, remote in self.prop.local_remote_pairs: | ||||
|             names.append(local.name) | ||||
|         return names | ||||
|  | ||||
|  | ||||
| class OneToManyFetcher(Fetcher): | ||||
|     def append_entity(self, entity): | ||||
|         self.parent_dict[getattr(entity, self.remote_column_name)].append( | ||||
|         #print 'appending entity ', entity, ' to key ', self.parent_key(entity) | ||||
|         self.parent_dict[self.parent_key(entity)].append( | ||||
|             entity | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def remote_column_names(self): | ||||
|         names = [] | ||||
|         for local, remote in self.prop.local_remote_pairs: | ||||
|             for fk in remote.foreign_keys: | ||||
|                 # TODO: make this support inherited tables | ||||
|                 if fk.column.table == self.path.parent_model.__table__: | ||||
|                     names.append(fk.parent.name) | ||||
|  | ||||
|         return names | ||||
|   | ||||
							
								
								
									
										111
									
								
								tests/batch_fetch/test_many_to_many_composite_keys.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								tests/batch_fetch/test_many_to_many_composite_keys.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy_utils import batch_fetch, with_backrefs | ||||
| from tests import TestCase | ||||
|  | ||||
|  | ||||
| class TestBatchFetchManyToManyCompositeRelationships(TestCase): | ||||
|     def create_models(self): | ||||
|         class Article(self.Base): | ||||
|             __tablename__ = 'article' | ||||
|             id1 = sa.Column(sa.Integer, primary_key=True) | ||||
|             id2 = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|         article_tag = sa.Table( | ||||
|             'article_tag', | ||||
|             self.Base.metadata, | ||||
|             sa.Column( | ||||
|                 'article_id1', | ||||
|                 sa.Integer, | ||||
|             ), | ||||
|             sa.Column( | ||||
|                 'article_id2', | ||||
|                 sa.Integer, | ||||
|             ), | ||||
|             sa.Column( | ||||
|                 'tag_id1', | ||||
|                 sa.Integer, | ||||
|             ), | ||||
|             sa.Column( | ||||
|                 'tag_id2', | ||||
|                 sa.Integer, | ||||
|             ), | ||||
|             sa.ForeignKeyConstraint( | ||||
|                 ['article_id1', 'article_id2'], | ||||
|                 ['article.id1', 'article.id2'] | ||||
|             ), | ||||
|             sa.ForeignKeyConstraint( | ||||
|                 ['tag_id1', 'tag_id2'], | ||||
|                 ['tag.id1', 'tag.id2'] | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         class Tag(self.Base): | ||||
|             __tablename__ = 'tag' | ||||
|             id1 = sa.Column(sa.Integer, primary_key=True) | ||||
|             id2 = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|             articles = sa.orm.relationship( | ||||
|                 Article, | ||||
|                 secondary=article_tag, | ||||
|                 backref=sa.orm.backref( | ||||
|                     'tags', | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|         self.Article = Article | ||||
|         self.Tag = Tag | ||||
|  | ||||
|     def setup_method(self, method): | ||||
|         TestCase.setup_method(self, method) | ||||
|         articles = [ | ||||
|             self.Article(id1=1, id2=2, name=u'Article 1'), | ||||
|             self.Article(id1=2, id2=2, name=u'Article 2'), | ||||
|             self.Article(id1=3, id2=3, name=u'Article 3'), | ||||
|             self.Article(id1=4, id2=3, name=u'Article 4'), | ||||
|             self.Article(id1=5, id2=3, name=u'Article 5') | ||||
|         ] | ||||
|         self.session.add_all(articles) | ||||
|         self.session.flush() | ||||
|  | ||||
|         tags = [ | ||||
|             self.Tag(id1=1, id2=2, name=u'Tag 1'), | ||||
|             self.Tag(id1=2, id2=3, name=u'Tag 2'), | ||||
|             self.Tag(id1=3, id2=4, name=u'Tag 3') | ||||
|         ] | ||||
|         articles[0].tags = tags | ||||
|         articles[3].tags = tags[1:] | ||||
|         self.session.commit() | ||||
|  | ||||
|     def test_deep_relationships(self): | ||||
|         articles = ( | ||||
|             self.session.query(self.Article) | ||||
|             .order_by(self.Article.id1).all() | ||||
|         ) | ||||
|         batch_fetch( | ||||
|             articles, | ||||
|             'tags' | ||||
|         ) | ||||
|         query_count = self.connection.query_count | ||||
|         assert articles[0].tags | ||||
|         articles[1].tags | ||||
|         assert articles[3].tags | ||||
|         assert self.connection.query_count == query_count | ||||
|  | ||||
|     def test_many_to_many_backref_population(self): | ||||
|         articles = ( | ||||
|             self.session.query(self.Article) | ||||
|             .order_by(self.Article.id1).all() | ||||
|         ) | ||||
|         batch_fetch( | ||||
|             articles, | ||||
|             with_backrefs('tags'), | ||||
|         ) | ||||
|         query_count = self.connection.query_count | ||||
|         tags = articles[0].tags | ||||
|         tags2 = articles[3].tags | ||||
|         tags[0].articles | ||||
|         tags2[0].articles | ||||
|         names = [article.name for article in tags[0].articles] | ||||
|         assert u'Article 1' in names | ||||
|         assert self.connection.query_count == query_count | ||||
		Reference in New Issue
	
	Block a user
	 Konsta Vesterinen
					Konsta Vesterinen