Files
deb-python-sqlalchemy-utils/sqlalchemy_utils/functions/orm.py
2014-02-20 13:27:17 +02:00

372 lines
8.6 KiB
Python

from functools import partial
from toolz import curry, first
import six
import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.util import AliasedInsp
def remove_property(class_, name):
    """
    **Experimental function**

    Detach the column property called ``name`` from a declarative class.

    The column is dropped from the mapper's column collection and from the
    per-table column list, the attribute is uninstrumented on the class
    manager and the entry is deleted from the mapper's property registry.
    Several private mapper attributes are touched directly.

    :param class_: declarative class exposing ``mapper`` and ``__table__``
    :param name: name of the property to remove
    """
    mapper = class_.mapper
    owning_table = class_.__table__
    column_collection = mapper.c

    # Fetch the column object first: it is still needed to clean up the
    # remaining collections once the name based entries are gone.
    target = column_collection[name]

    del column_collection._data[name]
    del mapper.columns[name]
    column_collection._all_cols.remove(target)
    mapper._cols_by_table[owning_table].remove(target)

    mapper.class_manager.uninstrument_attribute(name)
    del mapper._props[name]
def primary_keys(class_):
    """
    Yield every primary key column of the given declarative class.

    :param class_: declarative class whose ``__table__`` columns are scanned
    """
    for candidate in class_.__table__.c:
        if not candidate.primary_key:
            continue
        yield candidate
def table_name(obj):
    """
    Resolve the table name for the given target.

    If ``obj`` carries a ``class_`` attribute (e.g. a declarative attribute)
    the lookup is performed on that class, otherwise on ``obj`` itself.
    ``__tablename__`` is preferred; ``__table__.name`` is the fallback.

    :param obj: declarative class, table-like target or declarative attribute
    :returns: the table name, or ``None`` when neither lookup succeeds
    """
    target = getattr(obj, 'class_', obj)
    try:
        return target.__tablename__
    except AttributeError:
        try:
            return target.__table__.name
        except AttributeError:
            return None
def getattrs(obj, attrs):
    """
    Look up each name of ``attrs`` on ``obj``.

    :param obj: object to read the attributes from
    :param attrs: iterable of attribute names
    :returns: the ``map`` result holding the values in the order of ``attrs``
    """
    def fetch(attr_name):
        return getattr(obj, attr_name)

    return map(fetch, attrs)
def mapfirst(iterable):
    """
    Apply ``first`` to every element of ``iterable``.

    :param iterable: iterable whose elements are themselves iterables
    :returns: the ``map`` result of ``first`` over the elements
    """
    take_first = first
    return map(take_first, iterable)
@curry
def local_values(prop, entity):
    """
    Collect the values of the local columns of ``prop`` from ``entity``.

    Curried: calling it with only ``prop`` returns a function awaiting the
    entity.

    :param prop: relationship property describing the local columns
    :param entity: object the column values are read from
    :returns: tuple of values, ordered like ``local_column_names(prop)``
    """
    names = local_column_names(prop)
    return tuple(getattrs(entity, names))
def list_local_values(prop, entities):
    """
    Map ``local_values`` for ``prop`` over every entity.

    :param prop: relationship property describing the local columns
    :param entities: iterable of objects to read the values from
    :returns: the ``map`` result, one tuple of local values per entity
    """
    values_of = local_values(prop)
    return map(values_of, entities)
def remote_values(prop, entity):
    """
    Collect the values of the remote columns of ``prop`` from ``entity``.

    :param prop: relationship property describing the remote columns
    :param entity: object the column values are read from
    :returns: tuple of values, ordered like ``remote_column_names(prop)``
    """
    names = remote_column_names(prop)
    return tuple(getattrs(entity, names))
@curry
def local_remote_expr(prop, entity):
    """
    Build the SQL expression joining the remote side of ``prop`` to the
    local column values of ``entity``.

    For every local/remote column pair whose remote column name is among
    ``remote_column_names(prop)`` an equality between the remote attribute
    and the entity's local attribute is produced; the equalities are
    combined with ``sa.and_``.

    Curried: calling it with only ``prop`` returns a function awaiting the
    entity.

    :param prop: relationship property providing ``local_remote_pairs``
    :param entity: object supplying the local column values
    :returns: the ``sa.and_`` conjunction of the pairwise equalities
    """
    # ``remote_column_names`` yields column *names*. Compare names with
    # names: testing a Column object for membership among strings goes
    # through SQLAlchemy's overloaded ``==`` instead of a name comparison.
    # The set is built once rather than re-running the generator per pair.
    remote_names = set(remote_column_names(prop))
    return sa.and_(
        *[
            getattr(remote(prop), remote_col.name)
            ==
            getattr(entity, local_col.name)
            for local_col, remote_col in prop.local_remote_pairs
            if remote_col.name in remote_names
        ]
    )
def list_local_remote_exprs(prop, entities):
    """
    Map ``local_remote_expr`` for ``prop`` over every entity.

    :param prop: relationship property providing the column pairs
    :param entities: iterable of objects supplying the local values
    :returns: the ``map`` result, one expression per entity
    """
    build_expr = local_remote_expr(prop)
    return map(build_expr, entities)
def remote(prop):
    """
    Return the object holding the remote attributes of ``prop``.

    This is ``prop.secondary.c`` when that lookup succeeds; when it raises
    ``AttributeError`` (no ``secondary`` attribute, or a ``secondary``
    without ``c`` such as ``None``) the mapped class
    ``prop.mapper.class_`` is returned instead.

    :param prop: relationship property
    """
    try:
        remote_side = prop.secondary.c
    except AttributeError:
        remote_side = prop.mapper.class_
    return remote_side
def local_column_names(prop):
    """
    Yield the names of the local columns of the given property.

    Three cases are distinguished:

    * ``prop`` has no ``secondary`` attribute (presumably a generic
      relationship): the keys of its discriminator and id columns.
    * ``prop.secondary`` is ``None``: the name of the local column of each
      local/remote pair.
    * otherwise: the local column name once for every foreign key of the
      remote column that points at one of the parent's tables.

    :param prop: relationship property
    """
    if not hasattr(prop, 'secondary'):
        yield prop._discriminator_col.key
        yield prop._id_col.key
        return

    if prop.secondary is None:
        for local_col, _ in prop.local_remote_pairs:
            yield local_col.name
        return

    for local_col, remote_col in prop.local_remote_pairs:
        for foreign_key in remote_col.foreign_keys:
            if foreign_key.column.table in prop.parent.tables:
                yield local_col.name
def remote_column_names(prop):
    """
    Yield the names of the remote columns of the given property.

    Three cases are distinguished:

    * ``prop`` has no ``secondary`` attribute (presumably a generic
      relationship): the fixed names ``'__tablename__'`` and ``'id'``.
    * ``prop.secondary`` is ``None``: the name of the remote column of each
      local/remote pair.
    * otherwise: the remote column name once for every foreign key of that
      column that points at one of the parent's tables.

    :param prop: relationship property
    """
    if not hasattr(prop, 'secondary'):
        yield '__tablename__'
        yield 'id'
        return

    if prop.secondary is None:
        for _, remote_col in prop.local_remote_pairs:
            yield remote_col.name
        return

    for _, remote_col in prop.local_remote_pairs:
        for foreign_key in remote_col.foreign_keys:
            if foreign_key.column.table in prop.parent.tables:
                yield remote_col.name
def query_labels(query):
    """
    Yield the label names of the labeled column entities of a query.

    Example::

        query = session.query(
            Category,
            db.func.count(Article.id).label('articles')
        )

        list(query_labels(query))  # ['articles']

    :param query: SQLAlchemy Query object
    """
    for entity in query._entities:
        # Only column entities carry a label name; mapped entities are
        # skipped.
        if not isinstance(entity, _ColumnEntity):
            continue
        if entity._label_name:
            yield entity._label_name
def query_entities(query):
    """
    Yield the entities a query selects from and joins to.

    First the class of ``entity_zero`` of every query entity is yielded,
    then every join entity: mappers are replaced by their mapped class,
    anything else is yielded unchanged.

    Example::

        query = session.query(
            Category
        )

        list(query_entities(query))  # [Category]

    :param query: SQLAlchemy Query object
    """
    for selected in query._entities:
        yield selected.entity_zero.class_

    for joined in query._join_entities:
        yield joined.class_ if isinstance(joined, Mapper) else joined
def get_query_entity_by_alias(query, alias):
    """
    Find the query entity known under the given alias.

    With a falsy ``alias`` the first entity of the query is returned.
    Otherwise each entity is matched by name: aliased entities by their
    alias name, all others by the name of their ``__table__``.

    :param query: SQLAlchemy Query object
    :param alias: alias or table name to look for
    :returns: the matching entity, or ``None`` when nothing matches
    """
    candidates = query_entities(query)
    if not alias:
        return first(candidates)

    for candidate in candidates:
        label = (
            candidate.name
            if isinstance(candidate, AliasedInsp)
            else candidate.__table__.name
        )
        if label == alias:
            return candidate
    return None
def attrs(expr):
    """
    Return the mapped attributes of the given expression.

    For an aliased inspection object the attributes come from its mapper;
    anything else is passed through ``inspect`` first.

    :param expr: ``AliasedInsp`` instance or any inspectable target
    """
    if not isinstance(expr, AliasedInsp):
        return inspect(expr).attrs
    return expr.mapper.attrs
def get_expr_attr(expr, attr_name):
    """
    Fetch the attribute ``attr_name`` from the given expression.

    For an aliased inspection object the attribute is read from the columns
    of its selectable, otherwise directly from ``expr``.

    :param expr: ``AliasedInsp`` instance or any other object
    :param attr_name: name of the attribute to fetch
    """
    source = expr.selectable.c if isinstance(expr, AliasedInsp) else expr
    return getattr(source, attr_name)
def declarative_base(model):
    """
    Return the declarative base of the given model class.

    The class hierarchy is climbed through the first direct base that
    exposes a ``metadata`` attribute; the class where no such base exists
    any more is the result.

    :param model: SQLAlchemy declarative model
    """
    for candidate in model.__bases__:
        try:
            candidate.metadata
        except AttributeError:
            # Not part of the declarative hierarchy, try the next base.
            continue
        return declarative_base(candidate)
    return model
from operator import attrgetter
def getdotattr(obj_or_class, dot_path):
    """
    Allow dot-notated strings to be passed to `getattr`.

    ::

        getdotattr(SubSection, 'section.document')

        getdotattr(subsection, 'section.document')

    Each path segment is resolved on the result of the previous one:

    * a list is resolved element-wise; elements that are lists themselves
      are flattened into the result list
    * an ``InstrumentedAttribute`` is resolved on the class its property
      maps to
    * ``None`` short-circuits and ``None`` is returned
    * anything else is resolved with a plain attribute lookup

    :param obj_or_class: Any object or class
    :param dot_path:
        Attribute path with dot mark as separator. Non-string path objects
        are coerced with ``str`` before splitting.
    :returns: the resolved attribute value, a list of values, or ``None``
    """
    last = obj_or_class

    # Coerce object style paths to strings *before* splitting, so path
    # objects that only implement ``__str__`` are accepted too.
    for path in str(dot_path).split('.'):
        getter = attrgetter(path)

        if isinstance(last, list):
            tmp = []
            for el in last:
                if isinstance(el, list):
                    tmp.extend(map(getter, el))
                else:
                    tmp.append(getter(el))
            last = tmp
        elif isinstance(last, InstrumentedAttribute):
            last = getter(last.property.mapper.class_)
        elif last is None:
            return None
        else:
            last = getter(last)
    return last
def has_changes(obj, attr):
    """
    Tell whether the given attribute of a declarative model object has
    changed during the transaction.

    ::

        from sqlalchemy_utils import has_changes

        user = User()

        has_changes(user, 'name')  # False

        user.name = u'someone'

        has_changes(user, 'name')  # True

    :param obj: SQLAlchemy declarative model object
    :param attr: Name of the attribute
    """
    state = sa.inspect(obj)
    attribute_state = state.attrs.get(attr)
    return attribute_state.history.has_changes()
def identity(obj):
    """
    Return the identity of the given declarative model instance as a tuple.

    Unlike ``obj._sa_instance_state.identity`` the values are read straight
    from the object's attributes named after its primary key columns, so
    they are available as soon as those attributes are populated.

    ::

        from sqlalchemy_utils import identity

        user = User(name=u'John Matrix')
        identity(user)  # None

        user.id = 1
        identity(user)  # (1,)

    .. versionadded:: 0.21.0

    :param obj: SQLAlchemy declarative model object
    :returns:
        tuple of primary key values, or ``None`` when every primary key
        value is ``None``
    """
    mapper = sa.inspect(obj.__class__)
    key_values = [
        getattr(obj, column.name)
        for column in mapper.columns
        if column.primary_key
    ]

    if any(value is not None for value in key_values):
        return tuple(key_values)
    return None
def naturally_equivalent(obj, obj2):
    """
    Tell whether two declarative instances are naturally equivalent, i.e.
    whether all their non primary key column properties compare equal.

    ::

        from sqlalchemy_utils import naturally_equivalent

        user = User(name=u'someone')
        user2 = User(name=u'someone')

        user == user2  # False

        naturally_equivalent(user, user2)  # True

    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    mapper = sa.inspect(obj.__class__)
    for prop in mapper.iterate_properties:
        # Only plain column properties outside the primary key take part in
        # the comparison.
        is_column_property = isinstance(prop, sa.orm.ColumnProperty)
        if is_column_property and not prop.columns[0].primary_key:
            left = getattr(obj, prop.key)
            right = getattr(obj2, prop.key)
            if not (left == right):
                return False
    return True