Files
deb-python-sqlalchemy-utils/sqlalchemy_utils/__init__.py
2013-02-19 11:17:38 +02:00

127 lines
3.7 KiB
Python

from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.sql.expression import desc, asc
def sort_query(query, sort):
    """
    Applies an sql ORDER BY for given query

    The following examples use the following model definition:

        >>> import sqlalchemy as sa
        >>> from sqlalchemy import create_engine
        >>> from sqlalchemy.orm import sessionmaker
        >>> from sqlalchemy.ext.declarative import declarative_base
        >>> from sqlalchemy_utils import escape_like, sort_query
        >>>
        >>>
        >>> engine = create_engine(
        ...     'sqlite:///'
        ... )
        >>> Base = declarative_base()
        >>> Session = sessionmaker(bind=engine)
        >>> session = Session()
        >>>
        >>> class Category(Base):
        ...     __tablename__ = 'category'
        ...     id = sa.Column(sa.Integer, primary_key=True)
        ...     name = sa.Column(sa.Unicode(255))
        >>>
        >>> class Article(Base):
        ...     __tablename__ = 'article'
        ...     id = sa.Column(sa.Integer, primary_key=True)
        ...     name = sa.Column(sa.Unicode(255))
        ...     category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
        ...
        ...     category = sa.orm.relationship(
        ...         Category, primaryjoin=category_id == Category.id
        ...     )

    1. Applying simple ascending sort

        >>> query = session.query(Article)
        >>> query = sort_query(query, 'name')

    2. Applying descending sort

        >>> query = sort_query(query, '-name')

    3. Applying sort to custom calculated label

        >>> query = session.query(
        ...     Category, sa.func.count(Article.id).label('articles')
        ... )
        >>> query = sort_query(query, 'articles')

    4. Applying sort to joined table column

        >>> query = session.query(Article).join(Article.category)
        >>> query = sort_query(query, 'category-name')

    :param query: query to be modified
    :param sort: string that defines the label or column to sort the query
        by. A leading ``-`` selects descending order, and a
        ``<table name>-<column>`` value restricts the column lookup to the
        entity whose table has that name.
    :returns: the query with the ORDER BY applied, or the query unchanged
        when ``sort`` is empty or matches no label or column
    """
    # Nothing to sort by: hand the query back before inspecting any of its
    # private attributes.
    if not sort:
        return query

    entities = [entity.entity_zero.class_ for entity in query._entities]
    for mapper in query._join_entities:
        if isinstance(mapper, Mapper):
            entities.append(mapper.class_)
        else:
            entities.append(mapper)

    # get all label names for queries such as:
    # session.query(Category, sa.func.count(Article.id).label('articles'))
    labels = [
        entity._label_name
        for entity in query._entities
        if isinstance(entity, _ColumnEntity) and entity._label_name
    ]

    if sort.startswith('-'):
        func = desc
        sort = sort[1:]
    else:
        func = asc

    # "<table name>-<column>" limits the column lookup to a single table.
    component = None
    parts = sort.split('-')
    if len(parts) > 1:
        component = parts[0]
        sort = parts[1]

    # Labels take precedence over mapped columns.
    if sort in labels:
        return query.order_by(func(sort))

    for entity in entities:
        table = entity.__table__
        if component and table.name != component:
            continue
        if sort in table.columns:
            try:
                attr = getattr(entity, sort)
            except AttributeError:
                # Best effort: the table has the column but the entity
                # exposes no attribute under that name, so the query is
                # left unsorted.
                pass
            else:
                query = query.order_by(func(attr))
            # Only the first entity owning the column is considered.
            break
    return query
def escape_like(string, escape_char='*'):
    """
    Escapes the string parameter used in SQL LIKE expressions

    Every occurrence of ``escape_char``, ``%`` and ``_`` in ``string`` is
    prefixed with ``escape_char``.

    :param string: a string to escape
    :param escape_char: escape character
    :returns: the escaped copy of ``string``
    """
    # Double the escape character first so the prefixes added below are not
    # themselves escaped a second time.
    escaped = string.replace(escape_char, escape_char * 2)
    for wildcard in ('%', '_'):
        # A wildcard that is also the escape character was already prefixed
        # by the replacement above; escaping it again would emit four
        # characters instead of two.
        if wildcard != escape_char:
            escaped = escaped.replace(wildcard, escape_char + wildcard)
    return escaped