Merge branch 'master' of github.com:kvesteri/sqlalchemy-utils

This commit is contained in:
Ryan Leckey
2013-08-19 10:13:46 -07:00
10 changed files with 203 additions and 77 deletions

View File

@@ -4,10 +4,23 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.16.7 (2013-08-18)
^^^^^^^^^^^^^^^^^^^
- Added better handling of local column names in batch_fetch
- PasswordType gets default length even if no crypt context schemes provided
0.16.6 (2013-08-16)
^^^^^^^^^^^^^^^^^^^
- Rewritten batch_fetch schematics, new syntax for backref population
0.16.5 (2013-08-08) 0.16.5 (2013-08-08)
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
- Initial backref population forcing for batch_fetch - Initial backref population forcing support for batch_fetch
0.16.4 (2013-08-08) 0.16.4 (2013-08-08)

View File

@@ -55,7 +55,7 @@ for name, requirements in extras_require.items():
setup( setup(
name='SQLAlchemy-Utils', name='SQLAlchemy-Utils',
version='0.16.5', version='0.16.7',
url='https://github.com/kvesteri/sqlalchemy-utils', url='https://github.com/kvesteri/sqlalchemy-utils',
license='BSD', license='BSD',
author='Konsta Vesterinen', author='Konsta Vesterinen',

View File

@@ -37,7 +37,7 @@ from .types import (
) )
__version__ = '0.16.5' __version__ = '0.16.7'
__all__ = ( __all__ = (

View File

@@ -5,6 +5,10 @@ from sqlalchemy.orm.session import object_session
class with_backrefs(object): class with_backrefs(object):
"""
Marks given attribute path so that whenever its fetched with batch_fetch
the backref relations are force set too.
"""
def __init__(self, attr_path): def __init__(self, attr_path):
self.attr_path = attr_path self.attr_path = attr_path
@@ -70,49 +74,17 @@ def batch_fetch(entities, *attr_paths):
""" """
if entities: if entities:
fetcher = BatchFetcher(entities) fetcher = FetchingCoordinator(entities)
for attr_path in attr_paths: for attr_path in attr_paths:
fetcher(attr_path) fetcher(attr_path)
class BatchFetcher(object): class FetchingCoordinator(object):
def __init__(self, entities): def __init__(self, entities):
self.entities = entities self.entities = entities
self.first = entities[0] self.first = entities[0]
self.parent_ids = [entity.id for entity in entities]
self.session = object_session(self.first) self.session = object_session(self.first)
def populate_backrefs(self, related_entities):
"""
Populates backrefs for given related entities.
"""
backref_dict = dict(
(entity.id, []) for entity, parent_id in related_entities
)
for entity, parent_id in related_entities:
backref_dict[entity.id].append(
self.session.query(self.first.__class__).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
entity, self.prop.back_populates, backref_dict[entity.id]
)
def populate_entities(self):
"""
Populate batch fetched entities to parent objects.
"""
for entity in self.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[entity.id]
)
if self.should_populate_backrefs:
self.populate_backrefs(self.related_entities)
def parse_attr_path(self, attr_path, should_populate_backrefs): def parse_attr_path(self, attr_path, should_populate_backrefs):
if isinstance(attr_path, six.string_types): if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.') attrs = attr_path.split('.')
@@ -146,21 +118,96 @@ class BatchFetcher(object):
'are supported.' 'are supported.'
) )
column_name = list(self.prop.remote_side)[0].name def fetcher(self, property_):
if not isinstance(property_, RelationshipProperty):
self.related_entities = ( raise Exception(
self.session.query(self.model) 'Given attribute is not a relationship property.'
.filter(
getattr(self.model, column_name).in_(self.parent_ids)
) )
if property_.secondary is not None:
return ManyToManyFetcher(self, property_)
else:
if property_.direction.name == 'MANYTOONE':
return ManyToOneFetcher(self, property_)
else:
return OneToManyFetcher(self, property_)
def __call__(self, attr_path):
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
else:
self.should_populate_backrefs = False
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
if not attr:
return
fetcher = self.fetcher(attr.property)
fetcher.fetch()
fetcher.populate()
class Fetcher(object):
def __init__(self, coordinator, property_):
self.coordinator = coordinator
self.prop = property_
self.model = self.prop.mapper.class_
self.entities = coordinator.entities
self.first = self.entities[0]
self.session = object_session(self.first)
self.parent_dict = dict(
(self.local_values(entity), [])
for entity in self.entities
) )
for entity in self.related_entities: @property
self.parent_dict[getattr(entity, column_name)].append( def local_values_list(self):
entity return [
self.local_values(entity)
for entity in self.entities
]
def local_values(self, entity):
return getattr(entity, list(self.prop.local_columns)[0].name)
def populate_backrefs(self, related_entities):
"""
Populates backrefs for given related entities.
"""
backref_dict = dict(
(self.local_values(entity), [])
for entity, parent_id in related_entities
)
for entity, parent_id in related_entities:
backref_dict[self.local_values(entity)].append(
self.session.query(self.first.__class__).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
entity,
self.prop.back_populates,
backref_dict[self.local_values(entity)]
) )
def fetch_association_entities(self): def populate(self):
"""
Populate batch fetched entities to parent objects.
"""
for entity in self.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
)
if self.coordinator.should_populate_backrefs:
self.populate_backrefs(self.related_entities)
class ManyToManyFetcher(Fetcher):
def fetch(self):
column_name = None column_name = None
for column in self.prop.remote_side: for column in self.prop.remote_side:
for fk in column.foreign_keys: for fk in column.foreign_keys:
@@ -179,7 +226,7 @@ class BatchFetcher(object):
) )
.filter( .filter(
getattr(self.prop.secondary.c, column_name).in_( getattr(self.prop.secondary.c, column_name).in_(
self.parent_ids self.local_values_list
) )
) )
) )
@@ -188,30 +235,36 @@ class BatchFetcher(object):
entity entity
) )
def __call__(self, attr_path):
self.parent_dict = dict( class ManyToOneFetcher(Fetcher):
(entity.id, []) for entity in self.entities def fetch(self):
column_name = list(self.prop.remote_side)[0].name
self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.local_values_list)
)
) )
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
else:
self.should_populate_backrefs = False
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) for entity in self.related_entities:
if not attr: self.parent_dict[getattr(entity, column_name)].append(
return entity
self.prop = attr.property
if not isinstance(self.prop, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
) )
self.model = self.prop.mapper.class_
if self.prop.secondary is None: class OneToManyFetcher(Fetcher):
self.fetch_relation_entities() def fetch(self):
else: column_name = list(self.prop.remote_side)[0].name
self.fetch_association_entities()
self.populate_entities() self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.local_values_list)
)
)
for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)].append(
entity
)

View File

@@ -89,14 +89,17 @@ class PasswordType(types.TypeDecorator, ScalarCoercible):
if max_length is None: if max_length is None:
# Calculate the largest possible encoded password. # Calculate the largest possible encoded password.
# name + rounds + salt + hash + ($ * 4) of largest hash # name + rounds + salt + hash + ($ * 4) of largest hash
max_lengths = [] max_lengths = [1024]
for name in self.context.schemes(): for name in self.context.schemes():
scheme = getattr(__import__('passlib.hash').hash, name) scheme = getattr(__import__('passlib.hash').hash, name)
length = 4 + len(scheme.name) length = 4 + len(scheme.name)
length += len(str(getattr(scheme, 'max_rounds', ''))) length += len(str(getattr(scheme, 'max_rounds', '')))
length += scheme.max_salt_size or 0 length += scheme.max_salt_size or 0
length += getattr(scheme, 'encoded_checksum_size', length += getattr(
scheme.checksum_size) scheme,
'encoded_checksum_size',
scheme.checksum_size
)
max_lengths.append(length) max_lengths.append(length)
# Set the max_length to the maximum calculated max length. # Set the max_length to the maximum calculated max length.

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch, with_backrefs
from tests import TestCase from tests import TestCase
class TestBatchFetch(TestCase): class TestBatchFetchDeepRelationships(TestCase):
def create_models(self): def create_models(self):
class User(self.Base): class User(self.Base):
__tablename__ = 'user' __tablename__ = 'user'

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch
from tests import TestCase from tests import TestCase
class TestBatchFetch(TestCase): class TestBatchFetchAssociations(TestCase):
def create_models(self): def create_models(self):
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'

View File

@@ -0,0 +1,55 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch
from tests import TestCase
class TestBatchFetchManyToOneRelationships(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(
User,
backref=sa.orm.backref(
'articles'
)
)
self.User = User
self.Article = Article
def setup_method(self, method):
TestCase.setup_method(self, method)
articles = [
self.Article(
id=1,
name=u'Article 1',
author=self.User(id=333, name=u'John')
),
self.Article(
id=2,
name=u'Article 2',
author=self.User(id=334, name=u'Matt')
),
]
self.session.add_all(articles)
self.session.commit()
def test_supports_relationship_attributes(self):
articles = self.session.query(self.Article).all()
batch_fetch(
articles,
'author'
)
query_count = self.connection.query_count
assert articles[0].author # no lazy load should occur
assert articles[1].author # no lazy load should occur
assert self.connection.query_count == query_count

View File

@@ -4,7 +4,7 @@ from sqlalchemy_utils import batch_fetch
from tests import TestCase from tests import TestCase
class TestBatchFetch(TestCase): class TestBatchFetchOneToManyRelationships(TestCase):
def create_models(self): def create_models(self):
class User(self.Base): class User(self.Base):
__tablename__ = 'user' __tablename__ = 'user'

View File

@@ -8,7 +8,6 @@ from sqlalchemy_utils import Password, PasswordType
@mark.skipif('password.passlib is None') @mark.skipif('password.passlib is None')
class TestPasswordType(TestCase): class TestPasswordType(TestCase):
def create_models(self): def create_models(self):
class User(self.Base): class User(self.Base):
__tablename__ = 'user' __tablename__ = 'user'
@@ -86,3 +85,6 @@ class TestPasswordType(TestCase):
expected_length += 4 expected_length += 4
assert impl.length == expected_length assert impl.length == expected_length
def test_without_schemes(self):
assert PasswordType(schemes=[]).length == 1024