Add missing files

This commit is contained in:
Konsta Vesterinen
2013-12-04 15:54:57 +02:00
parent 0d92340f2b
commit 7f47f94461
2 changed files with 631 additions and 0 deletions

458
sqlalchemy_utils/batch.py Normal file
View File

@@ -0,0 +1,458 @@
from collections import defaultdict
from itertools import chain
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session
from sqlalchemy_utils.generic import (
GenericRelationshipProperty, class_from_table_name
)
class PathException(Exception):
pass
class with_backrefs(object):
"""
Marks given attribute path so that whenever its fetched with batch_fetch
the backref relations are force set too. Very useful when dealing with
certain many-to-many relationship scenarios.
"""
def __init__(self, path):
self.path = path
class Path(object):
"""
A class that represents an attribute path.
"""
def __init__(self, entities, prop, populate_backrefs=False):
self.property = prop
self.entities = entities
self.populate_backrefs = populate_backrefs
if (not isinstance(self.property, RelationshipProperty) and
not isinstance(self.property, GenericRelationshipProperty)):
raise PathException(
'Given attribute is not a relationship property.'
)
self.fetcher = self.fetcher_class(self)
@property
def session(self):
return object_session(self.entities[0])
@property
def parent_model(self):
return self.entities[0].__class__
@property
def model(self):
return self.property.mapper.class_
@classmethod
def parse(cls, entities, path, populate_backrefs=False):
if isinstance(path, six.string_types):
attrs = path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in entities:
related_entities.extend(getattr(entity, attrs[0]))
if not related_entities:
return
subpath = '.'.join(attrs[1:])
return Path.parse(related_entities, subpath, populate_backrefs)
else:
attr = getattr(
entities[0].__class__, attrs[0]
)
elif isinstance(path, InstrumentedAttribute):
attr = path
else:
raise PathException('Unknown path type.')
return Path(entities, attr.property, populate_backrefs)
@property
def fetcher_class(self):
if isinstance(self.property, GenericRelationshipProperty):
return GenericRelationshipFetcher
else:
if self.property.secondary is not None:
return ManyToManyFetcher
else:
if self.property.direction.name == 'MANYTOONE':
return ManyToOneFetcher
else:
return OneToManyFetcher
class CompositePath(object):
def __init__(self, *paths):
self.paths = paths
def batch_fetch(entities, *attr_paths):
"""
Batch fetch given relationship attribute for collection of entities.
This function is in many cases a valid alternative for SQLAlchemy's
subqueryload and performs lot better.
:param entities: list of entities of the same type
:param attr_paths:
List of either InstrumentedAttribute objects or a strings representing
the name of the instrumented attribute
Example::
from sqlalchemy_utils import batch_fetch
users = session.query(User).limit(20).all()
batch_fetch(users, User.phonenumbers)
Function also accepts strings as attribute names: ::
users = session.query(User).limit(20).all()
batch_fetch(users, 'phonenumbers')
Multiple attributes may be provided: ::
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
'teams.players.user_groups'
)
You can also force populate backrefs: ::
from sqlalchemy_utils import with_backrefs
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
with_backrefs('teams.players.user_groups')
)
"""
if entities:
for path in attr_paths:
fetcher = fetcher_factory(entities, path)
if fetcher:
fetcher.fetch()
fetcher.populate()
def fetcher_factory(entities, path):
populate_backrefs = False
if isinstance(path, with_backrefs):
path = path.path
populate_backrefs = True
if isinstance(path, CompositePath):
fetchers = []
for path in path.paths:
path = Path.parse(entities, path, populate_backrefs)
if path:
fetchers.append(
path.fetcher
)
return CompositeFetcher(*fetchers)
else:
path = Path.parse(entities, path, populate_backrefs)
if path:
return path.fetcher
class CompositeFetcher(object):
def __init__(self, *fetchers):
if not all(
fetchers[0].path.model == fetcher.path.model
for fetcher in fetchers
):
raise PathException(
'Each relationship property must have the same class when '
'using CompositeFetcher.'
)
self.fetchers = fetchers
@property
def session(self):
return self.fetchers[0].path.session
@property
def model(self):
return self.fetchers[0].path.model
@property
def condition(self):
return sa.or_(
*[fetcher.condition for fetcher in self.fetchers]
)
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
def fetch(self):
for entity in self.related_entities:
for fetcher in self.fetchers:
if any(
getattr(entity, name)
for name in fetcher.remote_column_names
):
fetcher.append_entity(entity)
def populate(self):
for fetcher in self.fetchers:
fetcher.populate()
class AbstractFetcher(object):
@property
def local_values_list(self):
return [
self.local_values(entity)
for entity in self.path.entities
]
def local_values(self, entity):
return tuple(
getattr(entity, name)
for name in self.local_column_names
)
class Fetcher(AbstractFetcher):
def __init__(self, path):
self.path = path
self.prop = self.path.property
if self.prop.uselist:
self.parent_dict = defaultdict(list)
else:
self.parent_dict = defaultdict(lambda: None)
def parent_key(self, entity):
return tuple(
getattr(entity, name)
for name in self.remote_column_names
)
@property
def relation_query_base(self):
return self.path.session.query(self.path.model)
@property
def related_entities(self):
return self.relation_query_base.filter(self.condition)
@property
def local_column_names(self):
return [local.name for local, remote in self.prop.local_remote_pairs]
def populate_backrefs(self, related_entities):
"""
Populates backrefs for given related entities.
"""
backref_dict = dict(
(self.local_values(value[0]), [])
for value in related_entities
)
for value in related_entities:
backref_dict[self.local_values(value[0])].append(
self.path.session.query(self.path.parent_model).get(
tuple(value[1:])
)
)
for value in related_entities:
set_committed_value(
value[0],
self.prop.back_populates,
backref_dict[self.local_values(value[0])]
)
def populate(self):
"""
Populate batch fetched entities to parent objects.
"""
for entity in self.path.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
)
if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
def remote(self):
return self.path.model
@property
def condition(self):
names = list(self.remote_column_names)
if len(names) == 1:
return getattr(self.remote, names[0]).in_(
value[0] for value in self.local_values_list
)
elif len(names) > 1:
conditions = []
for entity in self.path.entities:
conditions.append(
sa.and_(
*[
getattr(self.remote, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
if remote in self.remote_column_names
]
)
)
return sa.or_(*conditions)
else:
raise PathException(
'Could not obtain remote column names.'
)
def fetch(self):
for entity in self.related_entities:
self.append_entity(entity)
@property
def remote_column_names(self):
for local, remote in self.prop.local_remote_pairs:
yield remote.name
class GenericRelationshipFetcher(AbstractFetcher):
def __init__(self, path):
self.path = path
self.prop = self.path.property
self.parent_dict = defaultdict(lambda: None)
def fetch(self):
for entity in self.related_entities:
self.append_entity(entity)
def parent_key(self, entity):
return (entity.__tablename__, getattr(entity, 'id'))
def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)] = entity
@property
def local_column_names(self):
return (self.prop._discriminator_col.key, self.prop._id_col.key)
def populate(self):
"""
Populate batch fetched entities to parent objects.
"""
for entity in self.path.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
)
@property
def related_entities(self):
classes = []
id_dict = defaultdict(list)
for entity in self.path.entities:
discriminator = getattr(entity, self.prop._discriminator_col.key)
id_dict[discriminator].append(
getattr(entity, self.prop._id_col.key)
)
return chain(*self._queries(sa.inspect(entity), id_dict))
def _queries(self, state, id_dict):
for discriminator, ids in six.iteritems(id_dict):
class_ = class_from_table_name(
state, discriminator
)
yield self.path.session.query(
class_
).filter(
class_.id.in_(ids)
)
class ManyToManyFetcher(Fetcher):
@property
def remote(self):
return self.prop.secondary.c
@property
def local_column_names(self):
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
if fk.column.table in self.prop.parent.tables:
yield local.name
@property
def remote_column_names(self):
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
if fk.column.table in self.prop.parent.tables:
yield remote.name
@property
def relation_query_base(self):
return (
self.path.session
.query(
self.path.model,
*[
getattr(self.prop.secondary.c, name)
for name in self.remote_column_names
]
)
.join(
self.prop.secondary, self.prop.secondaryjoin
)
)
def fetch(self):
for value in self.related_entities:
self.parent_dict[tuple(value[1:])].append(
value[0]
)
class ManyToOneFetcher(Fetcher):
def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)] = entity
class OneToManyFetcher(Fetcher):
def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)].append(
entity
)

View File

@@ -0,0 +1,173 @@
import sqlalchemy as sa
from collections import defaultdict
def remove_property(class_, name):
"""
**Experimental function**
Remove property from declarative class
"""
mapper = class_.mapper
table = class_.__table__
columns = class_.mapper.c
column = columns[name]
del columns._data[name]
del mapper.columns[name]
columns._all_cols.remove(column)
mapper._cols_by_table[table].remove(column)
mapper.class_manager.uninstrument_attribute(name)
del mapper._props[name]
def primary_keys(class_):
"""
Returns all primary keys for given declarative class.
"""
for column in class_.__table__.c:
if column.primary_key:
yield column
def table_name(obj):
"""
Return table name of given target, declarative class or the
table name where the declarative attribute is bound to.
"""
class_ = getattr(obj, 'class_', obj)
try:
return class_.__tablename__
except AttributeError:
pass
try:
return class_.__table__.name
except AttributeError:
pass
def declarative_base(model):
"""
Returns the declarative base for given model class.
:param model: SQLAlchemy declarative model
"""
for parent in model.__bases__:
try:
parent.metadata
return declarative_base(parent)
except AttributeError:
pass
return model
def has_changes(obj, attr):
"""
Simple shortcut function for checking if given attribute of given
declarative model object has changed during the transaction.
::
from sqlalchemy_utils import has_changes
user = User()
has_changes(user, 'name') # False
user.name = u'someone'
has_changes(user, 'name') # True
:param obj: SQLAlchemy declarative model object
:param attr: Name of the attribute
"""
return (
sa.inspect(obj)
.attrs
.get(attr)
.history
.has_changes()
)
def identity(obj):
"""
Return the identity of given sqlalchemy declarative model instance as a
tuple. This differs from obj._sa_instance_state.identity in a way that it
always returns the identity even if object is still in transient state (
new object that is not yet persisted into database).
::
from sqlalchemy import inspect
from sqlalchemy_utils import identity
user = User(name=u'John Matrix')
session.add(user)
identity(user) # None
inspect(user).identity # None
session.flush() # User now has id but is still in transient state
identity(user) # (1,)
inspect(user).identity # None
session.commit()
identity(user) # (1,)
inspect(user).identity # (1, )
.. versionadded: 0.21.0
:param obj: SQLAlchemy declarative model object
"""
id_ = []
for column in sa.inspect(obj.__class__).columns:
if column.primary_key:
id_.append(getattr(obj, column.name))
if all(value is None for value in id_):
return None
else:
return tuple(id_)
def naturally_equivalent(obj, obj2):
"""
Returns whether or not two given SQLAlchemy declarative instances are
naturally equivalent (all their non primary key properties are equivalent).
::
from sqlalchemy_utils import naturally_equivalent
user = User(name=u'someone')
user2 = User(name=u'someone')
user == user2 # False
naturally_equivalent(user, user2) # True
:param obj: SQLAlchemy declarative model object
:param obj2: SQLAlchemy declarative model object to compare with `obj`
"""
for prop in sa.inspect(obj.__class__).iterate_properties:
if not isinstance(prop, sa.orm.ColumnProperty):
continue
if prop.columns[0].primary_key:
continue
if not (getattr(obj, prop.key) == getattr(obj2, prop.key)):
return False
return True