931 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			931 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| try:
 | |
|     from collections import OrderedDict
 | |
| except ImportError:
 | |
|     from ordereddict import OrderedDict
 | |
| 
 | |
| from functools import partial
 | |
| from inspect import isclass
 | |
| from operator import attrgetter
 | |
| 
 | |
| import six
 | |
| import sqlalchemy as sa
 | |
| from sqlalchemy.engine.interfaces import Dialect
 | |
| from sqlalchemy.ext.hybrid import hybrid_property
 | |
| from sqlalchemy.orm import mapperlib
 | |
| from sqlalchemy.orm.attributes import InstrumentedAttribute
 | |
| from sqlalchemy.orm.exc import UnmappedInstanceError
 | |
| from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty
 | |
| from sqlalchemy.orm.query import _ColumnEntity
 | |
| from sqlalchemy.orm.session import object_session
 | |
| from sqlalchemy.orm.util import AliasedInsp
 | |
| 
 | |
| from sqlalchemy_utils.utils import is_sequence
 | |
| 
 | |
| 
 | |
def get_class_by_table(base, table, data=None):
    """
    Return declarative class associated with given table. If no class is found
    this function returns `None`. If multiple classes were found (polymorphic
    cases) additional `data` parameter can be given to hint which class
    to return.

    ::

        class User(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        get_class_by_table(Base, User.__table__)  # User class


    This function also supports models using single table inheritance.
    Additional data parameter should be provided in these cases.

    ::

        class Entity(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            type = sa.Column(sa.String)
            __mapper_args__ = {
                'polymorphic_on': type,
                'polymorphic_identity': 'entity'
            }

        class User(Entity):
            __mapper_args__ = {
                'polymorphic_identity': 'user'
            }


        # Entity class
        get_class_by_table(Base, Entity.__table__, {'type': 'entity'})

        # User class
        get_class_by_table(Base, Entity.__table__, {'type': 'user'})


    :param base: Declarative model base
    :param table: SQLAlchemy Table object
    :param data: Data row to determine the class in polymorphic scenarios
    :return: Declarative class or None.
    :raises ValueError:
        If several classes map ``table`` and ``data`` is missing or matches
        none of their polymorphic identities.
    """
    # Registry entries without ``__table__`` (e.g. module markers) are ignored.
    found_classes = set(
        c for c in base._decl_class_registry.values()
        if hasattr(c, '__table__') and c.__table__ is table
    )
    if len(found_classes) > 1:
        if not data:
            raise ValueError(
                "Multiple declarative classes found for table '{0}'. "
                "Please provide data parameter for this function to be able "
                "to determine polymorphic scenarios.".format(
                    table.name
                )
            )
        for cls in found_classes:
            mapper = sa.inspect(cls)
            # Classes sharing a table without a polymorphic discriminator
            # cannot be distinguished by a data row; skip them rather than
            # failing with AttributeError on ``None.name``.
            if mapper.polymorphic_on is None:
                continue
            polymorphic_on = mapper.polymorphic_on.name
            if (
                polymorphic_on in data and
                data[polymorphic_on] == mapper.polymorphic_identity
            ):
                return cls
        raise ValueError(
            "Multiple declarative classes found for table '{0}'. Given "
            "data row does not match any polymorphic identity of the "
            "found classes.".format(
                table.name
            )
        )
    elif found_classes:
        return found_classes.pop()
    return None
 | |
| 
 | |
| 
 | |
def get_type(expr):
    """
    Return the type associated with a SQLAlchemy Column, InstrumentedAttribute,
    ColumnProperty, RelationshipProperty or a similar construct.

    Anything exposing a ``type`` attribute returns that attribute directly.
    Column properties return the type of their first column. Relationship
    properties return the mapped class on the other side of the relationship.

    :param expr:
        SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
        similar SA construct.
    :raises TypeError: if no type can be derived from ``expr``.

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
            author = sa.orm.relationship(User)


        get_type(User.__table__.c.name)  # sa.String()
        get_type(User.name)  # sa.String()
        get_type(User.name.property)  # sa.String()

        get_type(Article.author)  # User


    .. versionadded:: 0.30.9
    """
    if hasattr(expr, 'type'):
        return expr.type

    # Unwrap the instrumented class attribute to its MapperProperty.
    target = expr.property if isinstance(expr, InstrumentedAttribute) else expr

    if isinstance(target, ColumnProperty):
        return target.columns[0].type
    if isinstance(target, RelationshipProperty):
        return target.mapper.class_
    raise TypeError("Couldn't inspect type.")
 | |
| 
 | |
| 
 | |
def get_column_key(model, column):
    """
    Return the mapped attribute key for given column in given model.

    The mapper's own column lookup is tried first. If that fails, the mapped
    columns are scanned for one with the same name on the same table.

    :param model: SQLAlchemy declarative model object
    :param column: SQLAlchemy Column object
    :raises sqlalchemy.orm.exc.UnmappedColumnError:
        if no mapped attribute corresponds to ``column``.

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column('_name', sa.String)


        get_column_key(User, User.__table__.c._name)  # 'name'

    .. versionadded:: 0.26.5

    .. versionchanged:: 0.27.11
        Throws UnmappedColumnError instead of ValueError when no property was
        found for given column. This is consistent with how SQLAlchemy works.
    """
    mapper = sa.inspect(model)
    try:
        prop = mapper.get_property_by_column(column)
    except sa.orm.exc.UnmappedColumnError:
        pass
    else:
        return prop.key

    # Fallback: match by column name and owning table identity.
    for key, mapped_column in mapper.columns.items():
        same_name = mapped_column.name == column.name
        if same_name and mapped_column.table is column.table:
            return key

    raise sa.orm.exc.UnmappedColumnError(
        'No column %s is configured on mapper %s...' %
        (column, mapper)
    )
 | |
| 
 | |
| 
 | |
def get_mapper(mixed):
    """
    Return the SQLAlchemy Mapper related to the given SQLAlchemy object.

    :param mixed:
        SQLAlchemy Table / Alias / Mapper / declarative model class or
        instance, aliased class, Column, InstrumentedAttribute or a query
        mapper/column entity.

    ::

        from sqlalchemy_utils import get_mapper


        get_mapper(User)

        get_mapper(User())

        get_mapper(User.__table__)

        get_mapper(User.__mapper__)

        get_mapper(sa.orm.aliased(User))

        get_mapper(sa.orm.aliased(User.__table__))


    Raises:
        ValueError: if multiple mappers, or no mapper, were found for a
        given table.

    .. versionadded:: 0.26.1
    """
    # Unwrap query entities and columns to the object they refer to.
    if isinstance(mixed, sa.orm.query._MapperEntity):
        target = mixed.expr
    elif isinstance(mixed, sa.Column):
        target = mixed.table
    elif isinstance(mixed, sa.orm.query._ColumnEntity):
        target = mixed.expr
    else:
        target = mixed

    if isinstance(target, sa.orm.Mapper):
        return target
    if isinstance(target, sa.orm.util.AliasedClass):
        return sa.inspect(target).mapper
    if isinstance(target, sa.sql.selectable.Alias):
        target = target.element
    if isinstance(target, AliasedInsp):
        return target.mapper
    if isinstance(target, sa.orm.attributes.InstrumentedAttribute):
        target = target.class_

    if isinstance(target, sa.Table):
        # Search every configured mapper for the one mapping this table.
        matching = [
            candidate for candidate in mapperlib._mapper_registry
            if target in candidate.tables
        ]
        if not matching:
            raise ValueError(
                "Could not get mapper for table '%s'." % target.name
            )
        if len(matching) > 1:
            raise ValueError(
                "Multiple mappers found for table '%s'." % target.name
            )
        return matching[0]

    # Instances are resolved through their class.
    if not isclass(target):
        target = type(target)
    return sa.inspect(target)
 | |
| 
 | |
| 
 | |
def get_bind(obj):
    """
    Return the bind for given SQLAlchemy Session / Engine / Connection /
    declarative model object.

    Objects exposing a ``bind`` attribute (e.g. a Session) return that bind.
    Mapped instances return the bind of the session they belong to.
    Unmapped objects are returned as-is. In every case the result must look
    like an executable (have an ``execute`` attribute).

    :param obj: SQLAlchemy Session / Engine / Connection / declarative model
        object
    :raises TypeError:
        if no executable bind can be resolved. This includes mapped
        instances that are not attached to any session.

    ::

        from sqlalchemy_utils import get_bind


        get_bind(session)  # the session's bind (Engine or Connection)

        get_bind(user)
    """
    if hasattr(obj, 'bind'):
        conn = obj.bind
    else:
        try:
            session = object_session(obj)
        except UnmappedInstanceError:
            conn = obj
        else:
            # object_session() returns None for mapped instances that are
            # not attached to a session; previously this crashed with an
            # AttributeError on ``None.bind``.
            conn = session.bind if session is not None else None

    if not hasattr(conn, 'execute'):
        raise TypeError(
            'This method accepts only Session, Engine, Connection and '
            'declarative model objects.'
        )
    return conn
 | |
| 
 | |
| 
 | |
def get_primary_keys(mixed):
    """
    Return an OrderedDict of all primary keys for given Table object,
    declarative class or declarative class instance.

    Keys are the column keys reported by :func:`get_columns` and values are
    the Column objects, kept in the order :func:`get_columns` yields them.

    :param mixed:
        SA Table object, SA declarative class or SA declarative class instance

    ::

        get_primary_keys(User)

        get_primary_keys(User())

        get_primary_keys(User.__table__)

        get_primary_keys(User.__mapper__)

        get_primary_keys(sa.orm.aliased(User))

        get_primary_keys(sa.orm.aliased(User.__table__))


    .. versionchanged:: 0.25.3
        Made the function return an ordered dictionary instead of generator.
        This change was made to support primary key aliases.

        Renamed this function to 'get_primary_keys', formerly 'primary_keys'

    .. seealso:: :func:`get_columns`
    """
    primary_keys = OrderedDict()
    for key, column in get_columns(mixed).items():
        if column.primary_key:
            primary_keys[key] = column
    return primary_keys
 | |
| 
 | |
| 
 | |
def get_tables(mixed):
    """
    Return a list of tables associated with given SQLAlchemy object.

    Let's say we have three classes which use joined table inheritance
    TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.

    ::

        get_tables(Article)  # [Table('text_item', ...), Table('article', ...)]

        get_tables(Article())

        get_tables(Article.__mapper__)


    When the resolved mapper has polymorphic mappers (see
    ``get_polymorphic_mappers``), the tables of all of those mappers are
    concatenated. The result is a list, not a set, so a table shared by
    several mappers may appear more than once.

    ::

        get_tables(TextItem)  # [Table('text_item', ...), ...]


    .. versionadded:: 0.26.0

    :param mixed:
        SQLAlchemy Table, Mapper, Declarative class, Column,
        InstrumentedAttribute or a SA Alias object wrapping any of these
        objects.
    """
    if isinstance(mixed, sa.Table):
        return [mixed]
    elif isinstance(mixed, sa.Column):
        return [mixed.table]
    elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.parent.tables
    elif isinstance(mixed, sa.orm.query._ColumnEntity):
        mixed = mixed.expr

    mapper = get_mapper(mixed)

    polymorphic_mappers = get_polymorphic_mappers(mapper)
    if polymorphic_mappers:
        # Flatten in one linear pass; ``sum(lists, [])`` re-copies the
        # accumulated list on every addition, which is quadratic.
        return [
            table
            for polymorphic_mapper in polymorphic_mappers
            for table in polymorphic_mapper.tables
        ]
    return mapper.tables
 | |
| 
 | |
| 
 | |
def get_columns(mixed):
    """
    Return a collection of all Column objects for given SQLAlchemy
    object.

    Tables and aliases return their ``c`` collection. Mappers, aliased
    classes, declarative classes and instances return the mapper's
    ``columns`` collection, so the type of the result depends on the input.

    ::

        get_columns(User)

        get_columns(User())

        get_columns(User.__table__)

        get_columns(User.__mapper__)

        get_columns(sa.orm.aliased(User))

        get_columns(sa.orm.aliased(User.__table__))


    :param mixed:
        SA Table object, SA Mapper, SA declarative class, SA declarative class
        instance or an alias of any of these objects
    """
    if isinstance(mixed, sa.Table):
        return mixed.c
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper.columns
    if isinstance(mixed, sa.sql.selectable.Alias):
        return mixed.c
    if isinstance(mixed, sa.orm.Mapper):
        return mixed.columns

    # Declarative instances are resolved through their class.
    model = mixed if isclass(mixed) else mixed.__class__
    return sa.inspect(model).columns
 | |
| 
 | |
| 
 | |
def table_name(obj):
    """
    Return table name of given target, declarative class or the
    table name where the declarative attribute is bound to.

    If ``obj`` has a ``class_`` attribute (e.g. an instrumented attribute),
    that class is inspected instead. ``__tablename__`` takes precedence over
    ``__table__.name``. Returns None when neither is available.
    """
    target = getattr(obj, 'class_', obj)

    lookups = (attrgetter('__tablename__'), attrgetter('__table__.name'))
    for lookup in lookups:
        try:
            return lookup(target)
        except AttributeError:
            continue
    return None
 | |
| 
 | |
| 
 | |
def getattrs(obj, attrs):
    """
    Map every attribute name in ``attrs`` to its value on ``obj``.

    Returns the result of the built-in ``map``: a lazy iterator on Python 3
    and a list on Python 2.
    """
    return map(lambda name: getattr(obj, name), attrs)
 | |
| 
 | |
| 
 | |
def quote(mixed, ident):
    """
    Conditionally quote an identifier.

    The identifier is quoted by the dialect's identifier preparer, which
    only adds quotes when the dialect requires them (e.g. reserved words).

    ::

        from sqlalchemy_utils import quote


        engine = create_engine('sqlite:///:memory:')

        quote(engine, 'order')
        # '"order"'

        quote(engine, 'some_other_identifier')
        # 'some_other_identifier'


    :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
    :param ident: identifier to conditionally quote
    """
    dialect = mixed if isinstance(mixed, Dialect) else get_bind(mixed).dialect
    preparer = dialect.preparer(dialect)
    return preparer.quote(ident)
 | |
| 
 | |
| 
 | |
def query_labels(query):
    """
    Return all labels for given SQLAlchemy query object.

    Only column entities with a non-empty label name are included, in
    query entity order.

    Example::

        query = session.query(
            Category,
            db.func.count(Article.id).label('articles')
        )

        query_labels(query)  # ['articles']

    :param query: SQLAlchemy Query object
    """
    labels = []
    for entity in query._entities:
        if not isinstance(entity, _ColumnEntity):
            continue
        if entity._label_name:
            labels.append(entity._label_name)
    return labels
 | |
| 
 | |
| 
 | |
def get_query_entities(query):
    """
    Return a list of all entities present in given SQLAlchemy query object.

    Labeled sub-selects and plain Columns are kept as their expression.
    Other column descriptions contribute their entity. Joined entities are
    appended after the selected ones. Each item is normalized through
    ``get_query_entity``.

    Examples::

        from sqlalchemy_utils import get_query_entities


        query = session.query(Category)

        get_query_entities(query)  # [<Category>]


        query = session.query(Category.id)

        get_query_entities(query)  # [<Category>]


    This function also supports queries with joins.

    ::

        query = session.query(Category).join(Article)

        get_query_entities(query)  # [<Category>, <Article>]

    .. versionchanged:: 0.26.7
        This function now returns a list instead of generator

    :param query: SQLAlchemy Query object
    """
    selected = []
    for description in query.column_descriptions:
        expr = description['expr']
        if is_labeled_query(expr) or isinstance(expr, sa.Column):
            selected.append(expr)
        else:
            selected.append(description['entity'])

    entities = [get_query_entity(item) for item in selected]
    entities.extend(
        get_query_entity(joined) for joined in query._join_entities
    )
    return entities
 | |
| 
 | |
| 
 | |
def is_labeled_query(expr):
    """
    Return whether ``expr`` is a Label whose first base column is a
    ``Select`` or ``ScalarSelect`` (i.e. a labeled sub-query).
    """
    if not isinstance(expr, sa.sql.elements.Label):
        return False
    first_base = list(expr.base_columns)[0]
    select_types = (sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect)
    return isinstance(first_base, select_types)
 | |
| 
 | |
| 
 | |
def get_query_entity(expr):
    """
    Normalize a query expression to the entity it belongs to.

    Instrumented attributes resolve to their mapped class, Columns to their
    table and AliasedInsp objects to their aliased entity. Anything else is
    returned unchanged.
    """
    if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
        return expr.parent.class_
    if isinstance(expr, sa.Column):
        return expr.table
    if isinstance(expr, AliasedInsp):
        return expr.entity
    return expr
 | |
| 
 | |
| 
 | |
def get_query_entity_by_alias(query, alias):
    """
    Return the query entity whose alias (or first mapped table name) equals
    ``alias``.

    A falsy ``alias`` returns the first entity of the query. Returns None
    when no entity matches.
    """
    entities = get_query_entities(query)

    if not alias:
        return entities[0]

    for entity in entities:
        name = (
            sa.inspect(entity).name
            if isinstance(entity, sa.orm.util.AliasedClass)
            else get_mapper(entity).tables[0].name
        )
        if name == alias:
            return entity
    return None
 | |
| 
 | |
| 
 | |
def get_polymorphic_mappers(mixed):
    """
    Return the polymorphic mappers of a Mapper or AliasedInsp.

    AliasedInsp objects report their ``with_polymorphic_mappers``. Mappers
    report the values of their ``polymorphic_map``.
    """
    return (
        mixed.with_polymorphic_mappers
        if isinstance(mixed, AliasedInsp)
        else mixed.polymorphic_map.values()
    )
 | |
| 
 | |
| 
 | |
def get_query_descriptor(query, entity, attr):
    """
    Resolve ``attr`` within ``query`` to a label name or attribute descriptor.

    If ``attr`` is one of the query's labels, the label name itself is
    returned. Otherwise the entity matching the alias ``entity`` is looked up
    and its descriptor for ``attr`` is returned. Returns None when no entity
    matches or when the descriptor is a relationship.
    """
    if attr in query_labels(query):
        return attr

    resolved = get_query_entity_by_alias(query, entity)
    if not resolved:
        return None

    descriptor = get_descriptor(resolved, attr)
    # Relationships cannot be used as plain column descriptors.
    is_relationship = (
        hasattr(descriptor, 'property') and
        isinstance(descriptor.property, sa.orm.RelationshipProperty)
    )
    return None if is_relationship else descriptor
 | |
| 
 | |
| 
 | |
def get_descriptor(entity, attr):
    """
    Return the descriptor named ``attr`` for given entity, or None.

    Column properties on an aliased class resolve to the matching column of
    the alias selectable. Other column properties resolve through the class
    that owns the property, which supports attributes defined only on a
    polymorphic subclass. Synonyms, relationships and hybrid properties are
    read from the mapped class.
    """
    mapper = sa.inspect(entity)

    for key, descriptor in get_all_descriptors(mapper).items():
        if attr != key:
            continue

        if hasattr(descriptor, 'property'):
            prop = descriptor.property
        else:
            prop = None

        if not isinstance(prop, ColumnProperty):
            # Synonyms, relationship properties and hybrid properties.
            try:
                return getattr(mapper.class_, attr)
            except AttributeError:
                continue

        if not isinstance(entity, sa.orm.util.AliasedClass):
            # The attribute may exist on a polymorphic child class but not
            # on the parent, so look it up on the property's owner class.
            return getattr(prop.parent.class_, attr)

        for column in mapper.selectable.c:
            if column.key == attr:
                return column
    return None
 | |
| 
 | |
| 
 | |
def get_all_descriptors(expr):
    """
    Return all ORM descriptors for ``expr``, including those of its
    polymorphic mappers.

    Descriptors of the base mapper win over same-named descriptors of the
    polymorphic mappers. Without polymorphic mappers, the mapper's own
    ``all_orm_descriptors`` collection is returned unchanged.
    """
    submappers = get_polymorphic_mappers(sa.inspect(expr))
    if not submappers:
        return get_mapper(expr).all_orm_descriptors

    merged = dict(get_mapper(expr).all_orm_descriptors)
    for submapper in submappers:
        for key, descriptor in submapper.all_orm_descriptors.items():
            merged.setdefault(key, descriptor)
    return merged
 | |
| 
 | |
| 
 | |
def get_hybrid_properties(model):
    """
    Returns a dictionary of hybrid property keys and hybrid properties for
    given SQLAlchemy declarative model / mapper.

    Consider the following model

    ::

        from sqlalchemy.ext.hybrid import hybrid_property


        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            @hybrid_property
            def lowercase_name(self):
                return self.name.lower()

            @lowercase_name.expression
            def lowercase_name(cls):
                return sa.func.lower(cls.name)


    You can now easily get a list of all hybrid property names

    ::

        from sqlalchemy_utils import get_hybrid_properties


        get_hybrid_properties(Category).keys()  # ['lowercase_name']


    .. versionchanged:: 0.26.7
        This function now returns a dictionary instead of generator

    :param model: SQLAlchemy declarative model or mapper
    """
    descriptors = sa.inspect(model).all_orm_descriptors
    return {
        key: prop
        for key, prop in descriptors.items()
        if isinstance(prop, hybrid_property)
    }
 | |
| 
 | |
| 
 | |
def get_declarative_base(model):
    """
    Returns the declarative base for given model class.

    Walks up the first base class chain that exposes a ``metadata``
    attribute and returns the topmost such class. A class none of whose
    bases expose ``metadata`` is returned itself.

    :param model: SQLAlchemy declarative model
    """
    missing = object()
    for parent in model.__bases__:
        if getattr(parent, 'metadata', missing) is not missing:
            return get_declarative_base(parent)
    return model
 | |
| 
 | |
| 
 | |
def getdotattr(obj_or_class, dot_path, condition=None):
    """
    Allow dot-notated strings to be passed to `getattr`.

    Sequences are traversed element-wise and nested sequence results are
    flattened one level. Instrumented attributes are followed to their
    related mapped class. A None value short-circuits the lookup and
    returns None.

    ::

        getdotattr(SubSection, 'section.document')

        getdotattr(subsection, 'section.document')


    :param obj_or_class: Any object or class
    :param dot_path: Attribute path with dot mark as separator
    :param condition:
        Optional predicate applied after each path step. Sequence results
        are filtered by it. A scalar result failing it makes the function
        return None.
    """
    current = obj_or_class

    for segment in str(dot_path).split('.'):
        fetch = attrgetter(segment)

        if is_sequence(current):
            collected = []
            for item in current:
                result = fetch(item)
                if is_sequence(result):
                    collected.extend(result)
                else:
                    collected.append(result)
            current = collected
        elif isinstance(current, InstrumentedAttribute):
            current = fetch(current.property.mapper.class_)
        elif current is None:
            return None
        else:
            current = fetch(current)

        if condition is None:
            continue
        if is_sequence(current):
            current = [item for item in current if condition(item)]
        elif not condition(current):
            return None

    return current
 | |
| 
 | |
| 
 | |
def is_deleted(obj):
    """
    Return whether ``obj`` is marked as deleted in its session.

    Objects that are not attached to any session are reported as not
    deleted. Previously this raised AttributeError on ``None.deleted``.

    :param obj: SQLAlchemy declarative model object
    """
    session = sa.orm.object_session(obj)
    if session is None:
        return False
    return obj in session.deleted
 | |
| 
 | |
| 
 | |
def has_changes(obj, attrs=None, exclude=None):
    """
    Simple shortcut function for checking if given attributes of given
    declarative model object have changed during the session. Without
    parameters this checks if given object has any modifications. Additionally
    exclude parameter can be given to check if given object has any changes
    in any attributes other than the ones given in exclude.

    ::

        from sqlalchemy_utils import has_changes


        user = User()

        has_changes(user, 'name')  # False

        user.name = u'someone'

        has_changes(user, 'name')  # True

        has_changes(user)  # True


    You can check multiple attributes as well. The result is True if any of
    them has changed.

    ::

        has_changes(user, ['age'])  # False

        has_changes(user, ['name', 'age'])  # True


    This function also supports excluding certain attributes. ``exclude``
    is only consulted when ``attrs`` is not given.

    ::

        has_changes(user, exclude=['name'])  # False

        has_changes(user, exclude=['age'])  # True

    .. versionchanged:: 0.26.6
        Added support for multiple attributes and exclude parameter.

    :param obj: SQLAlchemy declarative model object
    :param attrs: Names of the attributes
    :param exclude: Names of the attributes to exclude
    """
    if not attrs:
        excluded = [] if exclude is None else exclude
        return any(
            attr_state.history.has_changes()
            for name, attr_state in sa.inspect(obj).attrs.items()
            if name not in excluded
        )

    if isinstance(attrs, six.string_types):
        attr_state = sa.inspect(obj).attrs.get(attrs)
        return attr_state.history.has_changes()

    return any(has_changes(obj, name) for name in attrs)
 | |
| 
 | |
| 
 | |
def is_loaded(obj, prop):
    """
    Return whether or not given property of given object has been loaded.

    ::

        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            content = sa.orm.deferred(sa.Column(sa.String))


        article = session.query(Article).get(5)

        # name gets loaded since its not a deferred property
        assert is_loaded(article, 'name')

        # content has not yet been loaded since its a deferred property
        assert not is_loaded(article, 'content')

        # an instrumented attribute may be passed instead of its name
        assert not is_loaded(article, Article.content)


    .. versionadded:: 0.27.8

    :param obj: SQLAlchemy declarative model object
    :param prop: Name of the property or InstrumentedAttribute
    """
    # getattr() requires a string name; an InstrumentedAttribute (as the
    # docstring allows) previously raised TypeError here.
    if isinstance(prop, InstrumentedAttribute):
        prop = prop.key
    loaded_value = getattr(sa.inspect(obj).attrs, prop).loaded_value
    # Unloaded attributes report a SQLAlchemy symbol sentinel (e.g. NO_VALUE).
    return not isinstance(loaded_value, sa.util.langhelpers._symbol)
 | |
| 
 | |
| 
 | |
def identity(obj_or_class):
    """
    Return the identity of given sqlalchemy declarative model class or instance
    as a tuple of its primary key attribute values.

    Unlike ``inspect(obj).identity``, this does not require the object to be
    persistent: the current primary key attribute values are read directly,
    so a transient object yields a tuple of ``None`` values. For classes it
    returns the primary key class attributes.

    ::

        from sqlalchemy_utils import identity


        user = User(name=u'John Matrix')
        session.add(user)
        identity(user)  # (None,)

        session.flush()

        identity(user)  # (1,)


    You can also use identity for classes::

        identity(User)  # (User.id, )

    .. versionadded:: 0.21.0

    :param obj_or_class: SQLAlchemy declarative model object or class
    """
    primary_keys = get_primary_keys(obj_or_class)
    values = [getattr(obj_or_class, key) for key in primary_keys]
    return tuple(values)
 | |
| 
 | |
| 
 | |
def naturally_equivalent(obj, obj2):
    """
    Returns whether or not two given SQLAlchemy declarative instances are
    naturally equivalent (all their non primary key properties are equivalent).

    Columns are taken from the mapper of ``obj``'s class and compared with
    ``==``, stopping at the first mismatch.

    ::

        from sqlalchemy_utils import naturally_equivalent


        user = User(name=u'someone')
        user2 = User(name=u'someone')

        user == user2  # False

        naturally_equivalent(user, user2)  # True


    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    mapped_columns = sa.inspect(obj.__class__).columns
    return all(
        getattr(obj, key) == getattr(obj2, key)
        for key, column in mapped_columns.items()
        if not column.primary_key
    )
 | 
