# ---------------------------------------------------------------------------
# File: sqlalchemy_utils/batch.py
# ---------------------------------------------------------------------------
from collections import defaultdict
from itertools import chain
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import (
    set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session
from sqlalchemy_utils.generic import (
    GenericRelationshipProperty, class_from_table_name
)


class PathException(Exception):
    """
    Raised when an attribute path or relationship cannot be handled by the
    batch fetchers.
    """


class with_backrefs(object):
    """
    Marks given attribute path so that whenever it is fetched with
    batch_fetch the backref relations are force set too. Very useful when
    dealing with certain many-to-many relationship scenarios.

    :param path: attribute path to wrap (anything fetcher_factory accepts)
    """
    def __init__(self, path):
        self.path = path


class Path(object):
    """
    A class that represents an attribute path.

    :param entities: parent entities the relationship is fetched for
    :param prop: relationship property the path points to
    :param populate_backrefs:
        whether the fetcher should also populate backrefs
    :raises PathException: if ``prop`` is not a relationship property
    """
    def __init__(self, entities, prop, populate_backrefs=False):
        self.property = prop
        self.entities = entities
        self.populate_backrefs = populate_backrefs
        relationship_types = (
            RelationshipProperty, GenericRelationshipProperty
        )
        if not isinstance(self.property, relationship_types):
            raise PathException(
                'Given attribute is not a relationship property.'
            )
        # fetcher_class reads self.property, so build the fetcher last.
        self.fetcher = self.fetcher_class(self)

    @property
    def session(self):
        """Session of the first entity (via object_session)."""
        return object_session(self.entities[0])

    @property
    def parent_model(self):
        """Class of the first entity."""
        return self.entities[0].__class__

    @property
    def model(self):
        """Class mapped at the target side of the relationship."""
        return self.property.mapper.class_

    @classmethod
    def parse(cls, entities, path, populate_backrefs=False):
        """
        Build a path object for given entities.

        For a dotted name the first segment is read from every entity and
        the rest of the name is parsed against the collected related
        entities.

        :param entities: list of entities of the same type
        :param path:
            attribute name (optionally dotted) or InstrumentedAttribute
        :param populate_backrefs: passed on to the created path
        :returns:
            path object, or None if an intermediate segment of a dotted
            name yields no related entities
        :raises PathException:
            if ``path`` is neither a string nor an InstrumentedAttribute
        """
        if isinstance(path, six.string_types):
            head, separator, subpath = path.partition('.')

            if separator:
                related_entities = []
                for entity in entities:
                    related_entities.extend(getattr(entity, head))

                if not related_entities:
                    return None
                return cls.parse(related_entities, subpath, populate_backrefs)

            attr = getattr(entities[0].__class__, head)
        elif isinstance(path, InstrumentedAttribute):
            attr = path
        else:
            raise PathException('Unknown path type.')

        return cls(entities, attr.property, populate_backrefs)

    @property
    def fetcher_class(self):
        """Fetcher class matching the kind of the relationship property."""
        if isinstance(self.property, GenericRelationshipProperty):
            return GenericRelationshipFetcher
        if self.property.secondary is not None:
            return ManyToManyFetcher
        if self.property.direction.name == 'MANYTOONE':
            return ManyToOneFetcher
        return OneToManyFetcher


class CompositePath(object):
    """
    Groups several attribute paths; fetcher_factory builds one
    CompositeFetcher for all of them.

    :param paths: attribute names or InstrumentedAttribute objects
    """
    def __init__(self, *paths):
        self.paths = paths


def batch_fetch(entities, *attr_paths):
    """
    Batch fetch given relationship attribute for collection of entities.

    This function is in many cases a valid alternative for SQLAlchemy's
    subqueryload and performs a lot better.

    :param entities: list of entities of the same type
    :param attr_paths:
        List of either InstrumentedAttribute objects or strings representing
        the name of the instrumented attribute

    Example::


        from sqlalchemy_utils import batch_fetch


        users = session.query(User).limit(20).all()

        batch_fetch(users, User.phonenumbers)


    Function also accepts strings as attribute names: ::


        users = session.query(User).limit(20).all()

        batch_fetch(users, 'phonenumbers')


    Multiple attributes may be provided: ::


        clubs = session.query(Club).limit(20).all()

        batch_fetch(
            clubs,
            'teams',
            'teams.players',
            'teams.players.user_groups'
        )

    You can also force populate backrefs: ::


        from sqlalchemy_utils import with_backrefs


        clubs = session.query(Club).limit(20).all()

        batch_fetch(
            clubs,
            'teams',
            'teams.players',
            with_backrefs('teams.players.user_groups')
        )

    """
    if not entities:
        return

    for path in attr_paths:
        fetcher = fetcher_factory(entities, path)
        if fetcher:
            fetcher.fetch()
            fetcher.populate()


def fetcher_factory(entities, path):
    """
    Return a fetcher for given entities and attribute path.

    :param entities: list of entities of the same type
    :param path:
        attribute name, InstrumentedAttribute or CompositePath, optionally
        wrapped in with_backrefs
    :returns:
        fetcher object, or None if the path (or every path of a
        CompositePath) resolved to no related entities
    """
    populate_backrefs = isinstance(path, with_backrefs)
    if populate_backrefs:
        path = path.path

    if not isinstance(path, CompositePath):
        parsed = Path.parse(entities, path, populate_backrefs)
        return parsed.fetcher if parsed else None

    fetchers = []
    for subpath in path.paths:
        parsed = Path.parse(entities, subpath, populate_backrefs)
        if parsed:
            fetchers.append(parsed.fetcher)

    # A CompositeFetcher without fetchers has no session or model to query
    # with (both are read from its first fetcher), so report "nothing to
    # fetch" instead of failing later in fetch().
    if not fetchers:
        return None
    return CompositeFetcher(*fetchers)


class CompositeFetcher(object):
    """
    Fetches the related entities of several fetchers with one query and
    hands each fetched entity to the fetchers it has key values for.

    :param fetchers: fetchers whose paths all target the same model
    :raises PathException: if the fetchers target different models
    """
    def __init__(self, *fetchers):
        if not all(
            fetchers[0].path.model == fetcher.path.model
            for fetcher in fetchers
        ):
            raise PathException(
                'Each relationship property must have the same class when '
                'using CompositeFetcher.'
            )
        self.fetchers = fetchers

    @property
    def session(self):
        """Session of the first fetcher's path."""
        return self.fetchers[0].path.session

    @property
    def model(self):
        """Model of the first fetcher's path."""
        return self.fetchers[0].path.model

    @property
    def condition(self):
        """OR of the conditions of all fetchers."""
        return sa.or_(
            *[fetcher.condition for fetcher in self.fetchers]
        )

    @property
    def related_entities(self):
        """Query for the model filtered by the combined condition."""
        return self.session.query(self.model).filter(self.condition)

    def fetch(self):
        """
        Run the combined query and append every entity to each fetcher for
        which the entity has at least one remote key value set.
        """
        for entity in self.related_entities:
            for fetcher in self.fetchers:
                # Compare against None rather than truthiness: a falsy key
                # value such as 0 is still a key.
                if any(
                    getattr(entity, name) is not None
                    for name in fetcher.remote_column_names
                ):
                    fetcher.append_entity(entity)

    def populate(self):
        """Populate the parent entities of every fetcher."""
        for fetcher in self.fetchers:
            fetcher.populate()


class AbstractFetcher(object):
    """
    Helpers for reading local key values of the parent entities. Expects
    ``path`` and ``local_column_names`` to be provided by the subclass.
    """
    @property
    def local_values_list(self):
        """List of local key value tuples, one per parent entity."""
        return [
            self.local_values(entity)
            for entity in self.path.entities
        ]

    def local_values(self, entity):
        """Tuple of the entity's values for ``local_column_names``."""
        return tuple(
            getattr(entity, name)
            for name in self.local_column_names
        )


class Fetcher(AbstractFetcher):
    """
    Base class for fetchers of regular relationship properties. Subclasses
    define ``append_entity`` (or override ``fetch``) to store fetched rows
    in ``parent_dict``.

    :param path: Path object this fetcher belongs to
    """
    def __init__(self, path):
        self.path = path
        self.prop = self.path.property
        # Maps key tuples to fetched entities; keys nothing was fetched for
        # default to an empty list (uselist) or None (scalar).
        if self.prop.uselist:
            self.parent_dict = defaultdict(list)
        else:
            self.parent_dict = defaultdict(lambda: None)

    def parent_key(self, entity):
        """Tuple of the entity's values for ``remote_column_names``."""
        return tuple(
            getattr(entity, name)
            for name in self.remote_column_names
        )

    @property
    def relation_query_base(self):
        """Unfiltered query for the related model."""
        return self.path.session.query(self.path.model)

    @property
    def related_entities(self):
        """``relation_query_base`` filtered by ``condition``."""
        return self.relation_query_base.filter(self.condition)

    @property
    def local_column_names(self):
        """Names of the local columns of the relationship."""
        return [local.name for local, remote in self.prop.local_remote_pairs]

    def populate_backrefs(self, related_entities):
        """
        Populates backrefs for given related entities.

        :param related_entities:
            iterable of rows; the first item of a row is the related entity
            and the remaining items are passed to ``Query.get`` to look up
            the parent object
        """
        # The argument may be a Query; materialize it so that it is
        # executed once instead of once per loop below.
        rows = list(related_entities)

        backref_dict = defaultdict(list)
        for row in rows:
            backref_dict[self.local_values(row[0])].append(
                self.path.session.query(self.path.parent_model).get(
                    tuple(row[1:])
                )
            )
        for row in rows:
            set_committed_value(
                row[0],
                self.prop.back_populates,
                backref_dict[self.local_values(row[0])]
            )

    def populate(self):
        """
        Populate batch fetched entities to parent objects.
        """
        for entity in self.path.entities:
            set_committed_value(
                entity,
                self.prop.key,
                self.parent_dict[self.local_values(entity)]
            )

        if self.path.populate_backrefs:
            self.populate_backrefs(self.related_entities)

    @property
    def remote(self):
        """Object the remote columns are read from (the related model)."""
        return self.path.model

    @property
    def condition(self):
        """
        Filter condition matching the rows related to the parent entities.

        :raises PathException: if there are no remote column names
        """
        names = list(self.remote_column_names)
        if not names:
            raise PathException(
                'Could not obtain remote column names.'
            )

        if len(names) == 1:
            return getattr(self.remote, names[0]).in_(
                [values[0] for values in self.local_values_list]
            )

        # Composite key: one AND clause per parent entity. Pairs are
        # matched by column *name*, since ``names`` holds strings.
        return sa.or_(*[
            sa.and_(*[
                getattr(self.remote, remote.name) ==
                getattr(entity, local.name)
                for local, remote in self.prop.local_remote_pairs
                if remote.name in names
            ])
            for entity in self.path.entities
        ])

    def fetch(self):
        """Query the related entities and store each with append_entity."""
        for entity in self.related_entities:
            self.append_entity(entity)

    @property
    def remote_column_names(self):
        """Names of the remote columns of the relationship."""
        for local, remote in self.prop.local_remote_pairs:
            yield remote.name


class GenericRelationshipFetcher(AbstractFetcher):
    """
    Fetcher for generic relationship properties, where each parent entity
    stores a discriminator (table name) and an id of the related object.

    :param path: Path object this fetcher belongs to
    """
    def __init__(self, path):
        self.path = path
        self.prop = self.path.property
        self.parent_dict = defaultdict(lambda: None)

    def fetch(self):
        """Query the related entities and store each with append_entity."""
        for entity in self.related_entities:
            self.append_entity(entity)

    def parent_key(self, entity):
        """(table name, id) key of a fetched related entity."""
        return (entity.__tablename__, entity.id)

    def append_entity(self, entity):
        self.parent_dict[self.parent_key(entity)] = entity

    @property
    def local_column_names(self):
        """Keys of the discriminator and id columns of the property."""
        return (self.prop._discriminator_col.key, self.prop._id_col.key)

    def populate(self):
        """
        Populate batch fetched entities to parent objects.
        """
        for entity in self.path.entities:
            set_committed_value(
                entity,
                self.prop.key,
                self.parent_dict[self.local_values(entity)]
            )

    @property
    def related_entities(self):
        """
        Iterator over the related entities, produced by one query per
        distinct discriminator value found on the parent entities.
        """
        discriminator_key = self.prop._discriminator_col.key
        id_key = self.prop._id_col.key

        id_dict = defaultdict(list)
        for entity in self.path.entities:
            discriminator = getattr(entity, discriminator_key)
            # Without a discriminator there is no table to look the related
            # object up from; populate() falls back to None for these.
            if discriminator is None:
                continue
            id_dict[discriminator].append(getattr(entity, id_key))

        state = sa.inspect(self.path.entities[-1])
        return chain.from_iterable(self._queries(state, id_dict))

    def _queries(self, state, id_dict):
        """
        Yield one query per discriminator, filtered by the collected ids.

        :param state: inspected state passed to class_from_table_name
        :param id_dict: dict of discriminator -> list of ids
        """
        for discriminator, ids in six.iteritems(id_dict):
            class_ = class_from_table_name(
                state, discriminator
            )
            yield self.path.session.query(
                class_
            ).filter(
                class_.id.in_(ids)
            )


class ManyToManyFetcher(Fetcher):
    """
    Fetcher for relationships that use a secondary table. Each fetched row
    holds the related entity followed by the secondary table's parent side
    key values.
    """
    @property
    def remote(self):
        """Columns of the secondary table."""
        return self.prop.secondary.c

    def _parent_side_pairs(self):
        """
        Yield (local, remote) pairs whose remote column has a foreign key
        to one of the parent mapper's tables (once per such foreign key).
        """
        for local, remote in self.prop.local_remote_pairs:
            for fk in remote.foreign_keys:
                if fk.column.table in self.prop.parent.tables:
                    yield local, remote

    @property
    def local_column_names(self):
        for local, remote in self._parent_side_pairs():
            yield local.name

    @property
    def remote_column_names(self):
        for local, remote in self._parent_side_pairs():
            yield remote.name

    @property
    def relation_query_base(self):
        """
        Query for the related model plus the parent side columns of the
        secondary table, joined through the secondary table.
        """
        secondary_columns = [
            getattr(self.prop.secondary.c, name)
            for name in self.remote_column_names
        ]
        return (
            self.path.session
            .query(self.path.model, *secondary_columns)
            .join(self.prop.secondary, self.prop.secondaryjoin)
        )

    def fetch(self):
        """Group fetched entities by the parent key values of their row."""
        for row in self.related_entities:
            self.parent_dict[tuple(row[1:])].append(row[0])


class ManyToOneFetcher(Fetcher):
    """Stores a single related entity per parent key."""
    def append_entity(self, entity):
        self.parent_dict[self.parent_key(entity)] = entity


class OneToManyFetcher(Fetcher):
    """Collects the related entities of each parent key into a list."""
    def append_entity(self, entity):
        self.parent_dict[self.parent_key(entity)].append(entity)


# ---------------------------------------------------------------------------
# File: sqlalchemy_utils/functions/orm.py
# ---------------------------------------------------------------------------
import sqlalchemy as sa
from collections import defaultdict


def remove_property(class_, name):
    """
    **Experimental function**

    Remove property from declarative class

    :param class_: declarative class to remove the property from
    :param name: name of the property (and of its column) to remove
    """
    mapper = class_.mapper
    table = class_.__table__
    columns = class_.mapper.c
    column = columns[name]
    # NOTE(review): the steps below reach into private mapper and column
    # collection attributes; presumably tied to a specific SQLAlchemy
    # version -- confirm before relying on this.
    del columns._data[name]
    del mapper.columns[name]
    columns._all_cols.remove(column)
    mapper._cols_by_table[table].remove(column)
    mapper.class_manager.uninstrument_attribute(name)
    del mapper._props[name]


def primary_keys(class_):
    """
    Returns all primary keys for given declarative class.

    :param class_: declarative class with a ``__table__``
    :returns: generator of the primary key columns of the table
    """
    for column in class_.__table__.c:
        if column.primary_key:
            yield column


def table_name(obj):
    """
    Return table name of given target, declarative class or the
    table name where the declarative attribute is bound to.

    :param obj:
        declarative class, or an object whose ``class_`` attribute is one
    :returns:
        ``__tablename__`` if available, else ``__table__.name``, else None
    """
    class_ = getattr(obj, 'class_', obj)

    try:
        return class_.__tablename__
    except AttributeError:
        pass

    try:
        return class_.__table__.name
    except AttributeError:
        pass

    return None


def declarative_base(model):
    """
    Returns the declarative base for given model class.

    Follows the first base class that has a ``metadata`` attribute until a
    class without such a base is reached.

    :param model: SQLAlchemy declarative model
    """
    for parent in model.__bases__:
        # Only the attribute probe is guarded, so errors raised while
        # walking further up are not silently swallowed.
        try:
            parent.metadata
        except AttributeError:
            continue
        return declarative_base(parent)
    return model


def has_changes(obj, attr):
    """
    Simple shortcut function for checking if given attribute of given
    declarative model object has changed during the transaction.


    ::


        from sqlalchemy_utils import has_changes


        user = User()

        has_changes(user, 'name')  # False

        user.name = u'someone'

        has_changes(user, 'name')  # True


    :param obj: SQLAlchemy declarative model object
    :param attr: Name of the attribute
    """
    return (
        sa.inspect(obj)
        .attrs
        .get(attr)
        .history
        .has_changes()
    )


def identity(obj):
    """
    Return the identity of given sqlalchemy declarative model instance as a
    tuple. This differs from obj._sa_instance_state.identity in a way that it
    always returns the identity even if object is still in transient state (
    new object that is not yet persisted into database).

    ::

        from sqlalchemy import inspect
        from sqlalchemy_utils import identity


        user = User(name=u'John Matrix')
        session.add(user)
        identity(user)  # None
        inspect(user).identity  # None

        session.flush()  # User now has id but is still in transient state

        identity(user)  # (1,)
        inspect(user).identity  # None

        session.commit()

        identity(user)  # (1,)
        inspect(user).identity  # (1, )


    .. versionadded:: 0.21.0

    :param obj: SQLAlchemy declarative model object
    :returns:
        tuple of primary key values, or None if all of them are None
    """
    mapper = sa.inspect(obj.__class__)
    # Read values through the mapped attribute key, which may differ from
    # the name of the underlying column.
    id_ = tuple(
        getattr(obj, mapper.get_property_by_column(column).key)
        for column in mapper.columns
        if column.primary_key
    )

    if all(value is None for value in id_):
        return None
    return id_


def naturally_equivalent(obj, obj2):
    """
    Returns whether or not two given SQLAlchemy declarative instances are
    naturally equivalent (all their non primary key properties are
    equivalent).


    ::

        from sqlalchemy_utils import naturally_equivalent


        user = User(name=u'someone')
        user2 = User(name=u'someone')

        user == user2  # False

        naturally_equivalent(user, user2)  # True


    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    for prop in sa.inspect(obj.__class__).iterate_properties:
        # Only column properties take part in the comparison.
        if not isinstance(prop, sa.orm.ColumnProperty):
            continue

        if prop.columns[0].primary_key:
            continue

        if not (getattr(obj, prop.key) == getattr(obj2, prop.key)):
            return False
    return True