diff --git a/CHANGES.rst b/CHANGES.rst index 67c22ca..8e85aaa 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,23 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.16.7 (2013-08-18) +^^^^^^^^^^^^^^^^^^^ + +- Added better handling of local column names in batch_fetch +- PasswordType gets default length even if no crypt context schemes provided + + +0.16.6 (2013-08-16) +^^^^^^^^^^^^^^^^^^^ + +- Rewritten batch_fetch schematics, new syntax for backref population + + 0.16.5 (2013-08-08) ^^^^^^^^^^^^^^^^^^^ -- Initial backref population forcing for batch_fetch +- Initial backref population forcing support for batch_fetch 0.16.4 (2013-08-08) diff --git a/setup.py b/setup.py index a4f68d9..960b7a7 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ for name, requirements in extras_require.items(): setup( name='SQLAlchemy-Utils', - version='0.16.5', + version='0.16.7', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen', diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 8cf3d49..42ca103 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -37,7 +37,7 @@ from .types import ( ) -__version__ = '0.16.5' +__version__ = '0.16.7' __all__ = ( diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index 2e99d62..7c8c1cf 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -5,6 +5,10 @@ from sqlalchemy.orm.session import object_session class with_backrefs(object): + """ + Marks given attribute path so that whenever its fetched with batch_fetch + the backref relations are force set too. + """ def __init__(self, attr_path): self.attr_path = attr_path @@ -70,49 +74,17 @@ def batch_fetch(entities, *attr_paths): """ if entities: - fetcher = BatchFetcher(entities) + fetcher = FetchingCoordinator(entities) for attr_path in attr_paths: fetcher(attr_path) -class BatchFetcher(object): +class FetchingCoordinator(object): def __init__(self, entities): self.entities = entities self.first = entities[0] - self.parent_ids = [entity.id for entity in entities] self.session = object_session(self.first) - def populate_backrefs(self, related_entities): - """ - Populates backrefs for given related entities. - """ - - backref_dict = dict( - (entity.id, []) for entity, parent_id in related_entities - ) - for entity, parent_id in related_entities: - backref_dict[entity.id].append( - self.session.query(self.first.__class__).get(parent_id) - ) - for entity, parent_id in related_entities: - set_committed_value( - entity, self.prop.back_populates, backref_dict[entity.id] - ) - - def populate_entities(self): - """ - Populate batch fetched entities to parent objects. - """ - for entity in self.entities: - set_committed_value( - entity, - self.prop.key, - self.parent_dict[entity.id] - ) - - if self.should_populate_backrefs: - self.populate_backrefs(self.related_entities) - def parse_attr_path(self, attr_path, should_populate_backrefs): if isinstance(attr_path, six.string_types): attrs = attr_path.split('.') @@ -146,21 +118,96 @@ class BatchFetcher(object): 'are supported.' ) - column_name = list(self.prop.remote_side)[0].name - - self.related_entities = ( - self.session.query(self.model) - .filter( - getattr(self.model, column_name).in_(self.parent_ids) + def fetcher(self, property_): + if not isinstance(property_, RelationshipProperty): + raise Exception( + 'Given attribute is not a relationship property.' ) + + if property_.secondary is not None: + return ManyToManyFetcher(self, property_) + else: + if property_.direction.name == 'MANYTOONE': + return ManyToOneFetcher(self, property_) + else: + return OneToManyFetcher(self, property_) + + def __call__(self, attr_path): + if isinstance(attr_path, with_backrefs): + self.should_populate_backrefs = True + attr_path = attr_path.attr_path + else: + self.should_populate_backrefs = False + + attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) + if not attr: + return + + fetcher = self.fetcher(attr.property) + fetcher.fetch() + fetcher.populate() + + +class Fetcher(object): + def __init__(self, coordinator, property_): + self.coordinator = coordinator + self.prop = property_ + self.model = self.prop.mapper.class_ + self.entities = coordinator.entities + self.first = self.entities[0] + self.session = object_session(self.first) + + self.parent_dict = dict( + (self.local_values(entity), []) + for entity in self.entities ) - for entity in self.related_entities: - self.parent_dict[getattr(entity, column_name)].append( - entity + @property + def local_values_list(self): + return [ + self.local_values(entity) + for entity in self.entities + ] + + def local_values(self, entity): + return getattr(entity, list(self.prop.local_columns)[0].name) + + def populate_backrefs(self, related_entities): + """ + Populates backrefs for given related entities. + """ + backref_dict = dict( + (self.local_values(entity), []) + for entity, parent_id in related_entities + ) + for entity, parent_id in related_entities: + backref_dict[self.local_values(entity)].append( + self.session.query(self.first.__class__).get(parent_id) + ) + for entity, parent_id in related_entities: + set_committed_value( + entity, + self.prop.back_populates, + backref_dict[self.local_values(entity)] ) - def fetch_association_entities(self): + def populate(self): + """ + Populate batch fetched entities to parent objects. + """ + for entity in self.entities: + set_committed_value( + entity, + self.prop.key, + self.parent_dict[self.local_values(entity)] + ) + + if self.coordinator.should_populate_backrefs: + self.populate_backrefs(self.related_entities) + + +class ManyToManyFetcher(Fetcher): + def fetch(self): column_name = None for column in self.prop.remote_side: for fk in column.foreign_keys: @@ -179,7 +226,7 @@ class BatchFetcher(object): ) .filter( getattr(self.prop.secondary.c, column_name).in_( - self.parent_ids + self.local_values_list ) ) ) @@ -188,30 +235,36 @@ class BatchFetcher(object): entity ) - def __call__(self, attr_path): - self.parent_dict = dict( - (entity.id, []) for entity in self.entities + +class ManyToOneFetcher(Fetcher): + def fetch(self): + column_name = list(self.prop.remote_side)[0].name + + self.related_entities = ( + self.session.query(self.model) + .filter( + getattr(self.model, column_name).in_(self.local_values_list) + ) ) - if isinstance(attr_path, with_backrefs): - self.should_populate_backrefs = True - attr_path = attr_path.attr_path - else: - self.should_populate_backrefs = False - attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) - if not attr: - return - - self.prop = attr.property - if not isinstance(self.prop, RelationshipProperty): - raise Exception( - 'Given attribute is not a relationship property.' + for entity in self.related_entities: + self.parent_dict[getattr(entity, column_name)].append( + entity ) - self.model = self.prop.mapper.class_ - if self.prop.secondary is None: - self.fetch_relation_entities() - else: - self.fetch_association_entities() - self.populate_entities() +class OneToManyFetcher(Fetcher): + def fetch(self): + column_name = list(self.prop.remote_side)[0].name + + self.related_entities = ( + self.session.query(self.model) + .filter( + getattr(self.model, column_name).in_(self.local_values_list) + ) + ) + + for entity in self.related_entities: + self.parent_dict[getattr(entity, column_name)].append( + entity + ) diff --git a/sqlalchemy_utils/types/password.py b/sqlalchemy_utils/types/password.py index 7e52d87..9c0f673 100644 --- a/sqlalchemy_utils/types/password.py +++ b/sqlalchemy_utils/types/password.py @@ -89,14 +89,17 @@ class PasswordType(types.TypeDecorator, ScalarCoercible): if max_length is None: # Calculate the largest possible encoded password. # name + rounds + salt + hash + ($ * 4) of largest hash - max_lengths = [] + max_lengths = [1024] for name in self.context.schemes(): scheme = getattr(__import__('passlib.hash').hash, name) length = 4 + len(scheme.name) length += len(str(getattr(scheme, 'max_rounds', ''))) length += scheme.max_salt_size or 0 - length += getattr(scheme, 'encoded_checksum_size', - scheme.checksum_size) + length += getattr( + scheme, + 'encoded_checksum_size', + scheme.checksum_size + ) max_lengths.append(length) # Set the max_length to the maximum calculated max length. diff --git a/tests/batch_fetch/test_deep_relationships.py b/tests/batch_fetch/test_deep_relationships.py index b837c29..a4b6737 100644 --- a/tests/batch_fetch/test_deep_relationships.py +++ b/tests/batch_fetch/test_deep_relationships.py @@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch, with_backrefs from tests import TestCase -class TestBatchFetch(TestCase): +class TestBatchFetchDeepRelationships(TestCase): def create_models(self): class User(self.Base): __tablename__ = 'user' diff --git a/tests/batch_fetch/test_join_table_inheritance.py b/tests/batch_fetch/test_join_table_inheritance.py index dfab812..95abdec 100644 --- a/tests/batch_fetch/test_join_table_inheritance.py +++ b/tests/batch_fetch/test_join_table_inheritance.py @@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch from tests import TestCase -class TestBatchFetch(TestCase): +class TestBatchFetchAssociations(TestCase): def create_models(self): class Category(self.Base): __tablename__ = 'category' diff --git a/tests/batch_fetch/test_many_to_one_relationships.py b/tests/batch_fetch/test_many_to_one_relationships.py new file mode 100644 index 0000000..36b0af4 --- /dev/null +++ b/tests/batch_fetch/test_many_to_one_relationships.py @@ -0,0 +1,55 @@ +import sqlalchemy as sa +from sqlalchemy_utils import batch_fetch +from tests import TestCase + + +class TestBatchFetchManyToOneRelationships(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + + author = sa.orm.relationship( + User, + backref=sa.orm.backref( + 'articles' + ) + ) + + self.User = User + self.Article = Article + + def setup_method(self, method): + TestCase.setup_method(self, method) + articles = [ + self.Article( + id=1, + name=u'Article 1', + author=self.User(id=333, name=u'John') + ), + self.Article( + id=2, + name=u'Article 2', + author=self.User(id=334, name=u'Matt') + ), + ] + self.session.add_all(articles) + self.session.commit() + + def test_supports_relationship_attributes(self): + articles = self.session.query(self.Article).all() + batch_fetch( + articles, + 'author' + ) + query_count = self.connection.query_count + assert articles[0].author # no lazy load should occur + assert articles[1].author # no lazy load should occur + assert self.connection.query_count == query_count diff --git a/tests/batch_fetch/test_simple_relationships.py b/tests/batch_fetch/test_one_to_many_relationships.py similarity index 97% rename from tests/batch_fetch/test_simple_relationships.py rename to tests/batch_fetch/test_one_to_many_relationships.py index 03dfed7..6ae00ea 100644 --- a/tests/batch_fetch/test_simple_relationships.py +++ b/tests/batch_fetch/test_one_to_many_relationships.py @@ -4,7 +4,7 @@ from sqlalchemy_utils import batch_fetch from tests import TestCase -class TestBatchFetch(TestCase): +class TestBatchFetchOneToManyRelationships(TestCase): def create_models(self): class User(self.Base): __tablename__ = 'user' diff --git a/tests/types/test_password.py b/tests/types/test_password.py index da4209c..5b1907b 100644 --- a/tests/types/test_password.py +++ b/tests/types/test_password.py @@ -8,7 +8,6 @@ from sqlalchemy_utils import Password, PasswordType @mark.skipif('password.passlib is None') class TestPasswordType(TestCase): - def create_models(self): class User(self.Base): __tablename__ = 'user' @@ -86,3 +85,6 @@ class TestPasswordType(TestCase): expected_length += 4 assert impl.length == expected_length + + def test_without_schemes(self): + assert PasswordType(schemes=[]).length == 1024