diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 8073418..80abf99 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,9 +1,9 @@ from .aggregates import aggregated +from .batch_fetch import batch_fetch, with_backrefs from .decorators import generates from .eav import MetaValue, MetaType from .exceptions import ImproperlyConfigured from .functions import ( - batch_fetch, defer_except, escape_like, identity, @@ -15,7 +15,6 @@ from .functions import ( mock_engine, sort_query, table_name, - with_backrefs, database_exists, create_database, drop_database diff --git a/sqlalchemy_utils/batch_fetch.py b/sqlalchemy_utils/batch_fetch.py deleted file mode 100644 index a76a796..0000000 --- a/sqlalchemy_utils/batch_fetch.py +++ /dev/null @@ -1,398 +0,0 @@ -from collections import defaultdict -import six -import sqlalchemy as sa -from sqlalchemy.orm import RelationshipProperty -from sqlalchemy.orm.attributes import ( - set_committed_value, InstrumentedAttribute -) -from sqlalchemy.orm.session import object_session - - -class PathException(Exception): - pass - - -class with_backrefs(object): - """ - Marks given attribute path so that whenever its fetched with batch_fetch - the backref relations are force set too. Very useful when dealing with - certain many-to-many relationship scenarios. - """ - def __init__(self, path): - self.path = path - - -class Path(object): - """ - A class that represents an attribute path. - """ - def __init__(self, entities, prop, populate_backrefs=False): - self.property = prop - self.entities = entities - self.populate_backrefs = populate_backrefs - if not isinstance(self.property, RelationshipProperty): - raise PathException( - 'Given attribute is not a relationship property.' 
- ) - self.fetcher = self.fetcher_class(self) - - @property - def session(self): - return object_session(self.entities[0]) - - @property - def parent_model(self): - return self.entities[0].__class__ - - @property - def model(self): - return self.property.mapper.class_ - - @classmethod - def parse(cls, entities, path, populate_backrefs=False): - if isinstance(path, six.string_types): - attrs = path.split('.') - - if len(attrs) > 1: - related_entities = [] - for entity in entities: - related_entities.extend(getattr(entity, attrs[0])) - - if not related_entities: - return - subpath = '.'.join(attrs[1:]) - return Path.parse(related_entities, subpath, populate_backrefs) - else: - attr = getattr( - entities[0].__class__, attrs[0] - ) - elif isinstance(path, InstrumentedAttribute): - attr = path - else: - raise PathException('Unknown path type.') - - return Path(entities, attr.property, populate_backrefs) - - @property - def fetcher_class(self): - if self.property.secondary is not None: - return ManyToManyFetcher - else: - if self.property.direction.name == 'MANYTOONE': - return ManyToOneFetcher - else: - return OneToManyFetcher - - -class CompositePath(object): - def __init__(self, *paths): - self.paths = paths - - -def batch_fetch(entities, *attr_paths): - """ - Batch fetch given relationship attribute for collection of entities. - - This function is in many cases a valid alternative for SQLAlchemy's - subqueryload and performs lot better. 
- - :param entities: list of entities of the same type - :param attr_paths: - List of either InstrumentedAttribute objects or a strings representing - the name of the instrumented attribute - - Example:: - - - from sqlalchemy_utils import batch_fetch - - - users = session.query(User).limit(20).all() - - batch_fetch(users, User.phonenumbers) - - - Function also accepts strings as attribute names: :: - - - users = session.query(User).limit(20).all() - - batch_fetch(users, 'phonenumbers') - - - Multiple attributes may be provided: :: - - - clubs = session.query(Club).limit(20).all() - - batch_fetch( - clubs, - 'teams', - 'teams.players', - 'teams.players.user_groups' - ) - - You can also force populate backrefs: :: - - - from sqlalchemy_utils import with_backrefs - - - clubs = session.query(Club).limit(20).all() - - batch_fetch( - clubs, - 'teams', - 'teams.players', - with_backrefs('teams.players.user_groups') - ) - - """ - - if entities: - for path in attr_paths: - fetcher = fetcher_factory(entities, path) - if fetcher: - fetcher.fetch() - fetcher.populate() - - -def fetcher_factory(entities, path): - populate_backrefs = False - if isinstance(path, with_backrefs): - path = path.path - populate_backrefs = True - - if isinstance(path, CompositePath): - fetchers = [] - for path in path.paths: - path = Path.parse(entities, path, populate_backrefs) - if path: - fetchers.append( - path.fetcher - ) - - return CompositeFetcher(*fetchers) - else: - path = Path.parse(entities, path, populate_backrefs) - if path: - return path.fetcher - - -class CompositeFetcher(object): - def __init__(self, *fetchers): - if not all( - fetchers[0].path.model == fetcher.path.model - for fetcher in fetchers - ): - raise PathException( - 'Each relationship property must have the same class when ' - 'using CompositeFetcher.' 
- ) - self.fetchers = fetchers - - @property - def session(self): - return self.fetchers[0].path.session - - @property - def model(self): - return self.fetchers[0].path.model - - @property - def condition(self): - return sa.or_( - *[fetcher.condition for fetcher in self.fetchers] - ) - - @property - def related_entities(self): - return self.session.query(self.model).filter(self.condition) - - def fetch(self): - for entity in self.related_entities: - for fetcher in self.fetchers: - if any( - getattr(entity, name) - for name in fetcher.remote_column_names - ): - fetcher.append_entity(entity) - - def populate(self): - for fetcher in self.fetchers: - fetcher.populate() - - -class Fetcher(object): - def __init__(self, path): - self.path = path - self.prop = self.path.property - if self.prop.uselist: - self.parent_dict = defaultdict(list) - else: - self.parent_dict = defaultdict(lambda: None) - - @property - def local_values_list(self): - return [ - self.local_values(entity) - for entity in self.path.entities - ] - - @property - def relation_query_base(self): - return self.path.session.query(self.path.model) - - @property - def related_entities(self): - return self.relation_query_base.filter(self.condition) - - @property - def local_column_names(self): - return [local.name for local, remote in self.prop.local_remote_pairs] - - def parent_key(self, entity): - return tuple( - getattr(entity, name) - for name in self.remote_column_names - ) - - def local_values(self, entity): - return tuple( - getattr(entity, name) - for name in self.local_column_names - ) - - def populate_backrefs(self, related_entities): - """ - Populates backrefs for given related entities. 
- """ - backref_dict = dict( - (self.local_values(value[0]), []) - for value in related_entities - ) - for value in related_entities: - backref_dict[self.local_values(value[0])].append( - self.path.session.query(self.path.parent_model).get( - tuple(value[1:]) - ) - ) - for value in related_entities: - set_committed_value( - value[0], - self.prop.back_populates, - backref_dict[self.local_values(value[0])] - ) - - def populate(self): - """ - Populate batch fetched entities to parent objects. - """ - for entity in self.path.entities: - set_committed_value( - entity, - self.prop.key, - self.parent_dict[self.local_values(entity)] - ) - - if self.path.populate_backrefs: - self.populate_backrefs(self.related_entities) - - @property - def remote(self): - return self.path.model - - @property - def condition(self): - names = self.remote_column_names - if len(names) == 1: - return getattr(self.remote, names[0]).in_( - value[0] for value in self.local_values_list - ) - elif len(names) > 1: - conditions = [] - for entity in self.path.entities: - conditions.append( - sa.and_( - *[ - getattr(self.remote, remote.name) - == - getattr(entity, local.name) - for local, remote in self.prop.local_remote_pairs - if remote in self.remote_column_names - ] - ) - ) - return sa.or_(*conditions) - else: - raise PathException( - 'Could not obtain remote column names.' 
- ) - - def fetch(self): - for entity in self.related_entities: - self.append_entity(entity) - - @property - def remote_column_names(self): - return [remote.name for local, remote in self.prop.local_remote_pairs] - - -class ManyToManyFetcher(Fetcher): - @property - def remote(self): - return self.prop.secondary.c - - @property - def local_column_names(self): - names = [] - for local, remote in self.prop.local_remote_pairs: - for fk in remote.foreign_keys: - if fk.column.table in self.prop.parent.tables: - names.append(local.name) - return names - - @property - def remote_column_names(self): - names = [] - for local, remote in self.prop.local_remote_pairs: - for fk in remote.foreign_keys: - if fk.column.table in self.prop.parent.tables: - names.append(remote.name) - return names - - @property - def relation_query_base(self): - return ( - self.path.session - .query( - self.path.model, - *[ - getattr(self.prop.secondary.c, name) - for name in self.remote_column_names - ] - ) - .join( - self.prop.secondary, self.prop.secondaryjoin - ) - ) - - def fetch(self): - for value in self.related_entities: - self.parent_dict[tuple(value[1:])].append( - value[0] - ) - - -class ManyToOneFetcher(Fetcher): - def append_entity(self, entity): - #print 'appending entity ', entity, ' to key ', self.parent_key(entity) - self.parent_dict[self.parent_key(entity)] = entity - - -class OneToManyFetcher(Fetcher): - def append_entity(self, entity): - #print 'appending entity ', entity, ' to key ', self.parent_key(entity) - self.parent_dict[self.parent_key(entity)].append( - entity - ) diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 9cce2a0..0831912 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -1,292 +1,48 @@ -from collections import defaultdict -import sqlalchemy as sa -from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint -from .batch_fetch import batch_fetch, with_backrefs, 
CompositePath from .defer_except import defer_except from .mock import create_mock_engine, mock_engine from .render import render_expression, render_statement from .sort_query import sort_query, QuerySorterException -from .database import database_exists, create_database, drop_database - +from .database import ( + database_exists, + create_database, + drop_database, + escape_like, + is_auto_assigned_date_column, + is_indexed_foreign_key, + non_indexed_foreign_keys, +) +from .orm import ( + primary_keys, + table_name, + declarative_base, + has_changes, + identity, + naturally_equivalent, + remove_property +) __all__ = ( - batch_fetch, create_mock_engine, defer_except, mock_engine, sort_query, render_expression, render_statement, - with_backrefs, - CompositePath, QuerySorterException, database_exists, create_database, - drop_database + drop_database, + escape_like, + is_auto_assigned_date_column, + is_indexed_foreign_key, + non_indexed_foreign_keys, + remove_property, + primary_keys, + table_name, + declarative_base, + has_changes, + identity, + naturally_equivalent, ) -def escape_like(string, escape_char='*'): - """ - Escapes the string paremeter used in SQL LIKE expressions - - >>> from sqlalchemy_utils import escape_like - >>> query = session.query(User).filter( - ... User.name.ilike(escape_like('John')) - ... 
) - - - :param string: a string to escape - :param escape_char: escape character - """ - return ( - string - .replace(escape_char, escape_char * 2) - .replace('%', escape_char + '%') - .replace('_', escape_char + '_') - ) - - -def remove_property(class_, name): - """ - **Experimental function** - - Remove property from declarative class - """ - mapper = class_.mapper - table = class_.__table__ - columns = class_.mapper.c - column = columns[name] - del columns._data[name] - del mapper.columns[name] - columns._all_cols.remove(column) - mapper._cols_by_table[table].remove(column) - mapper.class_manager.uninstrument_attribute(name) - del mapper._props[name] - - -def primary_keys(class_): - """ - Returns all primary keys for given declarative class. - """ - for column in class_.__table__.c: - if column.primary_key: - yield column - - -def table_name(obj): - """ - Return table name of given target, declarative class or the - table name where the declarative attribute is bound to. - """ - class_ = getattr(obj, 'class_', obj) - - try: - return class_.__tablename__ - except AttributeError: - pass - - try: - return class_.__table__.name - except AttributeError: - pass - - -def non_indexed_foreign_keys(metadata, engine=None): - """ - Finds all non indexed foreign keys from all tables of given MetaData. - - Very useful for optimizing postgresql database and finding out which - foreign keys need indexes. 
- - :param metadata: MetaData object to inspect tables from - """ - reflected_metadata = MetaData() - - if metadata.bind is None and engine is None: - raise Exception( - 'Either pass a metadata object with bind or ' - 'pass engine as a second parameter' - ) - - constraints = defaultdict(list) - - for table_name in metadata.tables.keys(): - table = Table( - table_name, - reflected_metadata, - autoload=True, - autoload_with=metadata.bind or engine - ) - - for constraint in table.constraints: - if not isinstance(constraint, ForeignKeyConstraint): - continue - - if not is_indexed_foreign_key(constraint): - constraints[table.name].append(constraint) - - return dict(constraints) - - -def is_indexed_foreign_key(constraint): - """ - Whether or not given foreign key constraint's columns have been indexed. - - :param constraint: ForeignKeyConstraint object to check the indexes - """ - for index in constraint.table.indexes: - index_column_names = set([ - column.name for column in index.columns - ]) - if index_column_names == set(constraint.columns): - return True - return False - - -def declarative_base(model): - """ - Returns the declarative base for given model class. - - :param model: SQLAlchemy declarative model - """ - for parent in model.__bases__: - try: - parent.metadata - return declarative_base(parent) - except AttributeError: - pass - return model - - -def is_auto_assigned_date_column(column): - """ - Returns whether or not given SQLAlchemy Column object's is auto assigned - DateTime or Date. - - :param column: SQLAlchemy Column object - """ - return ( - ( - isinstance(column.type, sa.DateTime) or - isinstance(column.type, sa.Date) - ) - and - ( - column.default or - column.server_default or - column.onupdate or - column.server_onupdate - ) - ) - - -def has_changes(obj, attr): - """ - Simple shortcut function for checking if given attribute of given - declarative model object has changed during the transaction. 
- - - :: - - - from sqlalchemy_utils import has_changes - - - user = User() - - has_changes(user, 'name') # False - - user.name = u'someone' - - has_changes(user, 'name') # True - - - :param obj: SQLAlchemy declarative model object - :param attr: Name of the attribute - """ - return ( - sa.inspect(obj) - .attrs - .get(attr) - .history - .has_changes() - ) - - -def identity(obj): - """ - Return the identity of given sqlalchemy declarative model instance as a - tuple. This differs from obj._sa_instance_state.identity in a way that it - always returns the identity even if object is still in transient state ( - new object that is not yet persisted into database). - - :: - - from sqlalchemy import inspect - from sqlalchemy_utils import identity - - - user = User(name=u'John Matrix') - session.add(user) - identity(user) # None - inspect(user).identity # None - - session.flush() # User now has id but is still in transient state - - identity(user) # (1,) - inspect(user).identity # None - - session.commit() - - identity(user) # (1,) - inspect(user).identity # (1, ) - - - .. versionadded: 0.21.0 - - :param obj: SQLAlchemy declarative model object - """ - id_ = [] - for column in sa.inspect(obj.__class__).columns: - if column.primary_key: - id_.append(getattr(obj, column.name)) - - if all(value is None for value in id_): - return None - else: - return tuple(id_) - - -def naturally_equivalent(obj, obj2): - """ - Returns whether or not two given SQLAlchemy declarative instances are - naturally equivalent (all their non primary key properties are equivalent). 
- - - :: - - from sqlalchemy_utils import naturally_equivalent - - - user = User(name=u'someone') - user2 = User(name=u'someone') - - user == user2 # False - - naturally_equivalent(user, user2) # True - - - :param obj: SQLAlchemy declarative model object - :param obj2: SQLAlchemy declarative model object to compare with `obj` - """ - for prop in sa.inspect(obj.__class__).iterate_properties: - if not isinstance(prop, sa.orm.ColumnProperty): - continue - - if prop.columns[0].primary_key: - continue - - if not (getattr(obj, prop.key) == getattr(obj2, prop.key)): - return False - return True diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py index 450ab1c..95b38cf 100644 --- a/sqlalchemy_utils/functions/database.py +++ b/sqlalchemy_utils/functions/database.py @@ -1,10 +1,55 @@ +from collections import defaultdict from sqlalchemy.engine.url import make_url import sqlalchemy as sa +from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint from sqlalchemy.exc import ProgrammingError, OperationalError import os from copy import copy +def escape_like(string, escape_char='*'): + """ + Escapes the string parameter used in SQL LIKE expressions + + >>> from sqlalchemy_utils import escape_like + >>> query = session.query(User).filter( + ... User.name.ilike(escape_like('John')) + ... ) + + + :param string: a string to escape + :param escape_char: escape character + """ + return ( + string + .replace(escape_char, escape_char * 2) + .replace('%', escape_char + '%') + .replace('_', escape_char + '_') + ) + + +def is_auto_assigned_date_column(column): + """ + Returns whether or not given SQLAlchemy Column object is an auto assigned + DateTime or Date. 
+ + :param column: SQLAlchemy Column object + """ + return ( + ( + isinstance(column.type, sa.DateTime) or + isinstance(column.type, sa.Date) + ) + and + ( + column.default or + column.server_default or + column.onupdate or + column.server_onupdate + ) + ) + + def database_exists(url): """Check if a database exists. @@ -137,3 +182,55 @@ def drop_database(url): else: text = "DROP DATABASE %s" % database engine.execute(text) + + +def non_indexed_foreign_keys(metadata, engine=None): + """ + Finds all non indexed foreign keys from all tables of given MetaData. + + Very useful for optimizing postgresql database and finding out which + foreign keys need indexes. + + :param metadata: MetaData object to inspect tables from + """ + reflected_metadata = MetaData() + + if metadata.bind is None and engine is None: + raise Exception( + 'Either pass a metadata object with bind or ' + 'pass engine as a second parameter' + ) + + constraints = defaultdict(list) + + for table_name in metadata.tables.keys(): + table = Table( + table_name, + reflected_metadata, + autoload=True, + autoload_with=metadata.bind or engine + ) + + for constraint in table.constraints: + if not isinstance(constraint, ForeignKeyConstraint): + continue + + if not is_indexed_foreign_key(constraint): + constraints[table.name].append(constraint) + + return dict(constraints) + + +def is_indexed_foreign_key(constraint): + """ + Whether or not given foreign key constraint's columns have been indexed. 
+ + :param constraint: ForeignKeyConstraint object to check the indexes + """ + for index in constraint.table.indexes: + index_column_names = set([ + column.name for column in index.columns + ]) + if index_column_names == set(constraint.columns): + return True + return False diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index 676cdb5..f2855d2 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -3,11 +3,18 @@ from sqlalchemy.orm.session import _state_session from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.util import set_creation_order from sqlalchemy import exc as sa_exc -from .functions import table_name +from sqlalchemy_utils.functions import table_name + + +def class_from_table_name(state, table): + for class_ in state.class_._decl_class_registry.values(): + name = table_name(class_) + if name and name == table: + return class_ + return None class GenericAttributeImpl(attributes.ScalarAttributeImpl): - def get(self, state, dict_, passive=attributes.PASSIVE_OFF): if self.key in dict_: return dict_[self.key] @@ -22,11 +29,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Find class for discriminator. # TODO: Perhaps optimize with some sort of lookup? discriminator = state.attrs[self.parent_token.discriminator.key].value - target_class = None - for class_ in state.class_._decl_class_registry.values(): - name = table_name(class_) - if name and name == discriminator: - target_class = class_ + target_class = class_from_table_name(state, discriminator) if target_class is None: # Unknown discriminator; return nothing. 
@@ -96,13 +99,13 @@ class GenericRelationshipProperty(MapperProperty): class Comparator(PropComparator): def __init__(self, prop, parentmapper): - self.prop = prop + self.property = prop self._parentmapper = parentmapper def __eq__(self, other): discriminator = table_name(other) - q = self.prop._discriminator_col == discriminator - q &= self.prop._id_col == other.id + q = self.property._discriminator_col == discriminator + q &= self.property._id_col == other.id return q def __ne__(self, other): @@ -110,7 +113,7 @@ class GenericRelationshipProperty(MapperProperty): def is_type(self, other): discriminator = table_name(other) - return self.prop._discriminator_col == discriminator + return self.property._discriminator_col == discriminator def instrument_class(self, mapper): attributes.register_attribute( diff --git a/tests/batch_fetch/test_compound_fetching.py b/tests/batch_fetch/test_compound_fetching.py index 977a330..7f74a3c 100644 --- a/tests/batch_fetch/test_compound_fetching.py +++ b/tests/batch_fetch/test_compound_fetching.py @@ -1,6 +1,6 @@ import sqlalchemy as sa from sqlalchemy_utils import batch_fetch -from sqlalchemy_utils.functions import CompositePath +from sqlalchemy_utils.batch import CompositePath from tests import TestCase