Add missing files
This commit is contained in:
458
sqlalchemy_utils/batch.py
Normal file
458
sqlalchemy_utils/batch.py
Normal file
@@ -0,0 +1,458 @@
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import RelationshipProperty
|
||||
from sqlalchemy.orm.attributes import (
|
||||
set_committed_value, InstrumentedAttribute
|
||||
)
|
||||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy_utils.generic import (
|
||||
GenericRelationshipProperty, class_from_table_name
|
||||
)
|
||||
|
||||
|
||||
class PathException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class with_backrefs(object):
|
||||
"""
|
||||
Marks given attribute path so that whenever its fetched with batch_fetch
|
||||
the backref relations are force set too. Very useful when dealing with
|
||||
certain many-to-many relationship scenarios.
|
||||
"""
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
|
||||
class Path(object):
|
||||
"""
|
||||
A class that represents an attribute path.
|
||||
"""
|
||||
def __init__(self, entities, prop, populate_backrefs=False):
|
||||
self.property = prop
|
||||
self.entities = entities
|
||||
self.populate_backrefs = populate_backrefs
|
||||
if (not isinstance(self.property, RelationshipProperty) and
|
||||
not isinstance(self.property, GenericRelationshipProperty)):
|
||||
raise PathException(
|
||||
'Given attribute is not a relationship property.'
|
||||
)
|
||||
self.fetcher = self.fetcher_class(self)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return object_session(self.entities[0])
|
||||
|
||||
@property
|
||||
def parent_model(self):
|
||||
return self.entities[0].__class__
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.property.mapper.class_
|
||||
|
||||
@classmethod
|
||||
def parse(cls, entities, path, populate_backrefs=False):
|
||||
if isinstance(path, six.string_types):
|
||||
attrs = path.split('.')
|
||||
|
||||
if len(attrs) > 1:
|
||||
related_entities = []
|
||||
for entity in entities:
|
||||
related_entities.extend(getattr(entity, attrs[0]))
|
||||
|
||||
if not related_entities:
|
||||
return
|
||||
subpath = '.'.join(attrs[1:])
|
||||
return Path.parse(related_entities, subpath, populate_backrefs)
|
||||
else:
|
||||
attr = getattr(
|
||||
entities[0].__class__, attrs[0]
|
||||
)
|
||||
elif isinstance(path, InstrumentedAttribute):
|
||||
attr = path
|
||||
else:
|
||||
raise PathException('Unknown path type.')
|
||||
|
||||
return Path(entities, attr.property, populate_backrefs)
|
||||
|
||||
@property
|
||||
def fetcher_class(self):
|
||||
if isinstance(self.property, GenericRelationshipProperty):
|
||||
return GenericRelationshipFetcher
|
||||
else:
|
||||
if self.property.secondary is not None:
|
||||
return ManyToManyFetcher
|
||||
else:
|
||||
if self.property.direction.name == 'MANYTOONE':
|
||||
return ManyToOneFetcher
|
||||
else:
|
||||
return OneToManyFetcher
|
||||
|
||||
|
||||
class CompositePath(object):
|
||||
def __init__(self, *paths):
|
||||
self.paths = paths
|
||||
|
||||
|
||||
def batch_fetch(entities, *attr_paths):
|
||||
"""
|
||||
Batch fetch given relationship attribute for collection of entities.
|
||||
|
||||
This function is in many cases a valid alternative for SQLAlchemy's
|
||||
subqueryload and performs lot better.
|
||||
|
||||
:param entities: list of entities of the same type
|
||||
:param attr_paths:
|
||||
List of either InstrumentedAttribute objects or a strings representing
|
||||
the name of the instrumented attribute
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
from sqlalchemy_utils import batch_fetch
|
||||
|
||||
|
||||
users = session.query(User).limit(20).all()
|
||||
|
||||
batch_fetch(users, User.phonenumbers)
|
||||
|
||||
|
||||
Function also accepts strings as attribute names: ::
|
||||
|
||||
|
||||
users = session.query(User).limit(20).all()
|
||||
|
||||
batch_fetch(users, 'phonenumbers')
|
||||
|
||||
|
||||
Multiple attributes may be provided: ::
|
||||
|
||||
|
||||
clubs = session.query(Club).limit(20).all()
|
||||
|
||||
batch_fetch(
|
||||
clubs,
|
||||
'teams',
|
||||
'teams.players',
|
||||
'teams.players.user_groups'
|
||||
)
|
||||
|
||||
You can also force populate backrefs: ::
|
||||
|
||||
|
||||
from sqlalchemy_utils import with_backrefs
|
||||
|
||||
|
||||
clubs = session.query(Club).limit(20).all()
|
||||
|
||||
batch_fetch(
|
||||
clubs,
|
||||
'teams',
|
||||
'teams.players',
|
||||
with_backrefs('teams.players.user_groups')
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
if entities:
|
||||
for path in attr_paths:
|
||||
fetcher = fetcher_factory(entities, path)
|
||||
if fetcher:
|
||||
fetcher.fetch()
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
def fetcher_factory(entities, path):
|
||||
populate_backrefs = False
|
||||
if isinstance(path, with_backrefs):
|
||||
path = path.path
|
||||
populate_backrefs = True
|
||||
|
||||
if isinstance(path, CompositePath):
|
||||
fetchers = []
|
||||
for path in path.paths:
|
||||
path = Path.parse(entities, path, populate_backrefs)
|
||||
if path:
|
||||
fetchers.append(
|
||||
path.fetcher
|
||||
)
|
||||
|
||||
return CompositeFetcher(*fetchers)
|
||||
else:
|
||||
path = Path.parse(entities, path, populate_backrefs)
|
||||
if path:
|
||||
return path.fetcher
|
||||
|
||||
|
||||
class CompositeFetcher(object):
|
||||
def __init__(self, *fetchers):
|
||||
if not all(
|
||||
fetchers[0].path.model == fetcher.path.model
|
||||
for fetcher in fetchers
|
||||
):
|
||||
raise PathException(
|
||||
'Each relationship property must have the same class when '
|
||||
'using CompositeFetcher.'
|
||||
)
|
||||
self.fetchers = fetchers
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.fetchers[0].path.session
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.fetchers[0].path.model
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
return sa.or_(
|
||||
*[fetcher.condition for fetcher in self.fetchers]
|
||||
)
|
||||
|
||||
@property
|
||||
def related_entities(self):
|
||||
return self.session.query(self.model).filter(self.condition)
|
||||
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
for fetcher in self.fetchers:
|
||||
if any(
|
||||
getattr(entity, name)
|
||||
for name in fetcher.remote_column_names
|
||||
):
|
||||
fetcher.append_entity(entity)
|
||||
|
||||
def populate(self):
|
||||
for fetcher in self.fetchers:
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
class AbstractFetcher(object):
|
||||
@property
|
||||
def local_values_list(self):
|
||||
return [
|
||||
self.local_values(entity)
|
||||
for entity in self.path.entities
|
||||
]
|
||||
|
||||
def local_values(self, entity):
|
||||
return tuple(
|
||||
getattr(entity, name)
|
||||
for name in self.local_column_names
|
||||
)
|
||||
|
||||
|
||||
class Fetcher(AbstractFetcher):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.prop = self.path.property
|
||||
if self.prop.uselist:
|
||||
self.parent_dict = defaultdict(list)
|
||||
else:
|
||||
self.parent_dict = defaultdict(lambda: None)
|
||||
|
||||
def parent_key(self, entity):
|
||||
return tuple(
|
||||
getattr(entity, name)
|
||||
for name in self.remote_column_names
|
||||
)
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
return self.path.session.query(self.path.model)
|
||||
|
||||
@property
|
||||
def related_entities(self):
|
||||
return self.relation_query_base.filter(self.condition)
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
return [local.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
def populate_backrefs(self, related_entities):
|
||||
"""
|
||||
Populates backrefs for given related entities.
|
||||
"""
|
||||
backref_dict = dict(
|
||||
(self.local_values(value[0]), [])
|
||||
for value in related_entities
|
||||
)
|
||||
for value in related_entities:
|
||||
backref_dict[self.local_values(value[0])].append(
|
||||
self.path.session.query(self.path.parent_model).get(
|
||||
tuple(value[1:])
|
||||
)
|
||||
)
|
||||
for value in related_entities:
|
||||
set_committed_value(
|
||||
value[0],
|
||||
self.prop.back_populates,
|
||||
backref_dict[self.local_values(value[0])]
|
||||
)
|
||||
|
||||
def populate(self):
|
||||
"""
|
||||
Populate batch fetched entities to parent objects.
|
||||
"""
|
||||
for entity in self.path.entities:
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[self.local_values(entity)]
|
||||
)
|
||||
|
||||
if self.path.populate_backrefs:
|
||||
self.populate_backrefs(self.related_entities)
|
||||
|
||||
@property
|
||||
def remote(self):
|
||||
return self.path.model
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
names = list(self.remote_column_names)
|
||||
if len(names) == 1:
|
||||
return getattr(self.remote, names[0]).in_(
|
||||
value[0] for value in self.local_values_list
|
||||
)
|
||||
elif len(names) > 1:
|
||||
conditions = []
|
||||
for entity in self.path.entities:
|
||||
conditions.append(
|
||||
sa.and_(
|
||||
*[
|
||||
getattr(self.remote, remote.name)
|
||||
==
|
||||
getattr(entity, local.name)
|
||||
for local, remote in self.prop.local_remote_pairs
|
||||
if remote in self.remote_column_names
|
||||
]
|
||||
)
|
||||
)
|
||||
return sa.or_(*conditions)
|
||||
else:
|
||||
raise PathException(
|
||||
'Could not obtain remote column names.'
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
self.append_entity(entity)
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
yield remote.name
|
||||
|
||||
|
||||
class GenericRelationshipFetcher(AbstractFetcher):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.prop = self.path.property
|
||||
self.parent_dict = defaultdict(lambda: None)
|
||||
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
self.append_entity(entity)
|
||||
|
||||
def parent_key(self, entity):
|
||||
return (entity.__tablename__, getattr(entity, 'id'))
|
||||
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[self.parent_key(entity)] = entity
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
return (self.prop._discriminator_col.key, self.prop._id_col.key)
|
||||
|
||||
def populate(self):
|
||||
"""
|
||||
Populate batch fetched entities to parent objects.
|
||||
"""
|
||||
for entity in self.path.entities:
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[self.local_values(entity)]
|
||||
)
|
||||
|
||||
@property
|
||||
def related_entities(self):
|
||||
classes = []
|
||||
id_dict = defaultdict(list)
|
||||
for entity in self.path.entities:
|
||||
discriminator = getattr(entity, self.prop._discriminator_col.key)
|
||||
id_dict[discriminator].append(
|
||||
getattr(entity, self.prop._id_col.key)
|
||||
)
|
||||
return chain(*self._queries(sa.inspect(entity), id_dict))
|
||||
|
||||
def _queries(self, state, id_dict):
|
||||
for discriminator, ids in six.iteritems(id_dict):
|
||||
class_ = class_from_table_name(
|
||||
state, discriminator
|
||||
)
|
||||
yield self.path.session.query(
|
||||
class_
|
||||
).filter(
|
||||
class_.id.in_(ids)
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ManyToManyFetcher(Fetcher):
|
||||
@property
|
||||
def remote(self):
|
||||
return self.prop.secondary.c
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
if fk.column.table in self.prop.parent.tables:
|
||||
yield local.name
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
if fk.column.table in self.prop.parent.tables:
|
||||
yield remote.name
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
return (
|
||||
self.path.session
|
||||
.query(
|
||||
self.path.model,
|
||||
*[
|
||||
getattr(self.prop.secondary.c, name)
|
||||
for name in self.remote_column_names
|
||||
]
|
||||
)
|
||||
.join(
|
||||
self.prop.secondary, self.prop.secondaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
for value in self.related_entities:
|
||||
self.parent_dict[tuple(value[1:])].append(
|
||||
value[0]
|
||||
)
|
||||
|
||||
|
||||
class ManyToOneFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[self.parent_key(entity)] = entity
|
||||
|
||||
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[self.parent_key(entity)].append(
|
||||
entity
|
||||
)
|
173
sqlalchemy_utils/functions/orm.py
Normal file
173
sqlalchemy_utils/functions/orm.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def remove_property(class_, name):
|
||||
"""
|
||||
**Experimental function**
|
||||
|
||||
Remove property from declarative class
|
||||
"""
|
||||
mapper = class_.mapper
|
||||
table = class_.__table__
|
||||
columns = class_.mapper.c
|
||||
column = columns[name]
|
||||
del columns._data[name]
|
||||
del mapper.columns[name]
|
||||
columns._all_cols.remove(column)
|
||||
mapper._cols_by_table[table].remove(column)
|
||||
mapper.class_manager.uninstrument_attribute(name)
|
||||
del mapper._props[name]
|
||||
|
||||
|
||||
def primary_keys(class_):
|
||||
"""
|
||||
Returns all primary keys for given declarative class.
|
||||
"""
|
||||
for column in class_.__table__.c:
|
||||
if column.primary_key:
|
||||
yield column
|
||||
|
||||
|
||||
def table_name(obj):
|
||||
"""
|
||||
Return table name of given target, declarative class or the
|
||||
table name where the declarative attribute is bound to.
|
||||
"""
|
||||
class_ = getattr(obj, 'class_', obj)
|
||||
|
||||
try:
|
||||
return class_.__tablename__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return class_.__table__.name
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def declarative_base(model):
|
||||
"""
|
||||
Returns the declarative base for given model class.
|
||||
|
||||
:param model: SQLAlchemy declarative model
|
||||
"""
|
||||
for parent in model.__bases__:
|
||||
try:
|
||||
parent.metadata
|
||||
return declarative_base(parent)
|
||||
except AttributeError:
|
||||
pass
|
||||
return model
|
||||
|
||||
|
||||
def has_changes(obj, attr):
|
||||
"""
|
||||
Simple shortcut function for checking if given attribute of given
|
||||
declarative model object has changed during the transaction.
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import has_changes
|
||||
|
||||
|
||||
user = User()
|
||||
|
||||
has_changes(user, 'name') # False
|
||||
|
||||
user.name = u'someone'
|
||||
|
||||
has_changes(user, 'name') # True
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param attr: Name of the attribute
|
||||
"""
|
||||
return (
|
||||
sa.inspect(obj)
|
||||
.attrs
|
||||
.get(attr)
|
||||
.history
|
||||
.has_changes()
|
||||
)
|
||||
|
||||
|
||||
def identity(obj):
|
||||
"""
|
||||
Return the identity of given sqlalchemy declarative model instance as a
|
||||
tuple. This differs from obj._sa_instance_state.identity in a way that it
|
||||
always returns the identity even if object is still in transient state (
|
||||
new object that is not yet persisted into database).
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy_utils import identity
|
||||
|
||||
|
||||
user = User(name=u'John Matrix')
|
||||
session.add(user)
|
||||
identity(user) # None
|
||||
inspect(user).identity # None
|
||||
|
||||
session.flush() # User now has id but is still in transient state
|
||||
|
||||
identity(user) # (1,)
|
||||
inspect(user).identity # None
|
||||
|
||||
session.commit()
|
||||
|
||||
identity(user) # (1,)
|
||||
inspect(user).identity # (1, )
|
||||
|
||||
|
||||
.. versionadded: 0.21.0
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
"""
|
||||
id_ = []
|
||||
for column in sa.inspect(obj.__class__).columns:
|
||||
if column.primary_key:
|
||||
id_.append(getattr(obj, column.name))
|
||||
|
||||
if all(value is None for value in id_):
|
||||
return None
|
||||
else:
|
||||
return tuple(id_)
|
||||
|
||||
|
||||
def naturally_equivalent(obj, obj2):
|
||||
"""
|
||||
Returns whether or not two given SQLAlchemy declarative instances are
|
||||
naturally equivalent (all their non primary key properties are equivalent).
|
||||
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import naturally_equivalent
|
||||
|
||||
|
||||
user = User(name=u'someone')
|
||||
user2 = User(name=u'someone')
|
||||
|
||||
user == user2 # False
|
||||
|
||||
naturally_equivalent(user, user2) # True
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param obj2: SQLAlchemy declarative model object to compare with `obj`
|
||||
"""
|
||||
for prop in sa.inspect(obj.__class__).iterate_properties:
|
||||
if not isinstance(prop, sa.orm.ColumnProperty):
|
||||
continue
|
||||
|
||||
if prop.columns[0].primary_key:
|
||||
continue
|
||||
|
||||
if not (getattr(obj, prop.key) == getattr(obj2, prop.key)):
|
||||
return False
|
||||
return True
|
Reference in New Issue
Block a user