Added NumberRange types, refactored file structure
This commit is contained in:
@@ -1,372 +1,27 @@
|
|||||||
import phonenumbers
|
from .functions import sort_query, defer_except, escape_like
|
||||||
from functools import wraps
|
from .merge import merge, Merger
|
||||||
import sqlalchemy as sa
|
from .types import (
|
||||||
from sqlalchemy.engine import reflection
|
instrumented_list,
|
||||||
from sqlalchemy.orm import defer, object_session, mapperlib
|
InstrumentedList,
|
||||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
PhoneNumber,
|
||||||
from sqlalchemy.orm.mapper import Mapper
|
PhoneNumberType,
|
||||||
from sqlalchemy.orm.query import _ColumnEntity
|
NumberRange,
|
||||||
from sqlalchemy.orm.properties import ColumnProperty
|
NumberRangeRawType,
|
||||||
from sqlalchemy.sql.expression import desc, asc
|
NumberRangeType
|
||||||
from sqlalchemy import types
|
)
|
||||||
|
|
||||||
|
|
||||||
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
|
__all__ = (
|
||||||
'''
|
sort_query,
|
||||||
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
defer_except,
|
||||||
different phone number formats to attributes, so they can be easily used
|
escape_like,
|
||||||
in templates. Phone number validation method is also implemented.
|
instrumented_list,
|
||||||
|
merge,
|
||||||
Takes the raw phone number and country code as params and parses them
|
InstrumentedList,
|
||||||
into a PhoneNumber object.
|
Merger,
|
||||||
|
NumberRange,
|
||||||
.. _Python phonenumbers library:
|
NumberRangeRawType,
|
||||||
https://github.com/daviddrysdale/python-phonenumbers
|
NumberRangeType,
|
||||||
|
PhoneNumber,
|
||||||
:param raw_number:
|
PhoneNumberType,
|
||||||
String representation of the phone number.
|
)
|
||||||
:param country_code:
|
|
||||||
Country code of the phone number.
|
|
||||||
'''
|
|
||||||
def __init__(self, raw_number, country_code=None):
|
|
||||||
self._phone_number = phonenumbers.parse(raw_number, country_code)
|
|
||||||
super(PhoneNumber, self).__init__(
|
|
||||||
country_code=self._phone_number.country_code,
|
|
||||||
national_number=self._phone_number.national_number,
|
|
||||||
extension=self._phone_number.extension,
|
|
||||||
italian_leading_zero=self._phone_number.italian_leading_zero,
|
|
||||||
raw_input=self._phone_number.raw_input,
|
|
||||||
country_code_source=self._phone_number.country_code_source,
|
|
||||||
preferred_domestic_carrier_code=
|
|
||||||
self._phone_number.preferred_domestic_carrier_code
|
|
||||||
)
|
|
||||||
self.national = phonenumbers.format_number(
|
|
||||||
self._phone_number,
|
|
||||||
phonenumbers.PhoneNumberFormat.NATIONAL
|
|
||||||
)
|
|
||||||
self.international = phonenumbers.format_number(
|
|
||||||
self._phone_number,
|
|
||||||
phonenumbers.PhoneNumberFormat.INTERNATIONAL
|
|
||||||
)
|
|
||||||
self.e164 = phonenumbers.format_number(
|
|
||||||
self._phone_number,
|
|
||||||
phonenumbers.PhoneNumberFormat.E164
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_valid_number(self):
|
|
||||||
return phonenumbers.is_valid_number(self._phone_number)
|
|
||||||
|
|
||||||
|
|
||||||
class PhoneNumberType(types.TypeDecorator):
|
|
||||||
"""
|
|
||||||
Changes PhoneNumber objects to a string representation on the way in and
|
|
||||||
changes them back to PhoneNumber objects on the way out. If E164 is used
|
|
||||||
as storing format, no country code is needed for parsing the database
|
|
||||||
value to PhoneNumber object.
|
|
||||||
"""
|
|
||||||
STORE_FORMAT = 'e164'
|
|
||||||
impl = types.Unicode(20)
|
|
||||||
|
|
||||||
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
|
|
||||||
super(PhoneNumberType, self).__init__(*args, **kwargs)
|
|
||||||
self.country_code = country_code
|
|
||||||
self.impl = types.Unicode(max_length)
|
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
|
||||||
return getattr(value, self.STORE_FORMAT)
|
|
||||||
|
|
||||||
def process_result_value(self, value, dialect):
|
|
||||||
return PhoneNumber(value, self.country_code)
|
|
||||||
|
|
||||||
|
|
||||||
class InstrumentedList(_InstrumentedList):
|
|
||||||
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
|
||||||
additional functionality."""
|
|
||||||
|
|
||||||
def any(self, attr):
|
|
||||||
return any(getattr(item, attr) for item in self)
|
|
||||||
|
|
||||||
def all(self, attr):
|
|
||||||
return all(getattr(item, attr) for item in self)
|
|
||||||
|
|
||||||
|
|
||||||
def instrumented_list(f):
|
|
||||||
@wraps(f)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
return InstrumentedList([item for item in f(*args, **kwargs)])
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def sort_query(query, sort):
|
|
||||||
"""
|
|
||||||
Applies an sql ORDER BY for given query. This function can be easily used
|
|
||||||
with user-defined sorting.
|
|
||||||
|
|
||||||
The examples use the following model definition:
|
|
||||||
|
|
||||||
>>> import sqlalchemy as sa
|
|
||||||
>>> from sqlalchemy import create_engine
|
|
||||||
>>> from sqlalchemy.orm import sessionmaker
|
|
||||||
>>> from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
>>> from sqlalchemy_utils import sort_query
|
|
||||||
>>>
|
|
||||||
>>>
|
|
||||||
>>> engine = create_engine(
|
|
||||||
... 'sqlite:///'
|
|
||||||
... )
|
|
||||||
>>> Base = declarative_base()
|
|
||||||
>>> Session = sessionmaker(bind=engine)
|
|
||||||
>>> session = Session()
|
|
||||||
>>>
|
|
||||||
>>> class Category(Base):
|
|
||||||
... __tablename__ = 'category'
|
|
||||||
... id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
... name = sa.Column(sa.Unicode(255))
|
|
||||||
>>>
|
|
||||||
>>> class Article(Base):
|
|
||||||
... __tablename__ = 'article'
|
|
||||||
... id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
... name = sa.Column(sa.Unicode(255))
|
|
||||||
... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
|
||||||
...
|
|
||||||
... category = sa.orm.relationship(
|
|
||||||
... Category, primaryjoin=category_id == Category.id
|
|
||||||
... )
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
1. Applying simple ascending sort
|
|
||||||
|
|
||||||
>>> query = session.query(Article)
|
|
||||||
>>> query = sort_query(query, 'name')
|
|
||||||
|
|
||||||
2. Appying descending sort
|
|
||||||
|
|
||||||
>>> query = sort_query(query, '-name')
|
|
||||||
|
|
||||||
3. Applying sort to custom calculated label
|
|
||||||
|
|
||||||
>>> query = session.query(
|
|
||||||
... Category, db.func.count(Article.id).label('articles')
|
|
||||||
... )
|
|
||||||
>>> query = sort_query(query, 'articles')
|
|
||||||
|
|
||||||
4. Applying sort to joined table column
|
|
||||||
|
|
||||||
>>> query = session.query(Article).join(Article.category)
|
|
||||||
>>> query = sort_query(query, 'category-name')
|
|
||||||
|
|
||||||
|
|
||||||
:param query: query to be modified
|
|
||||||
:param sort: string that defines the label or column to sort the query by
|
|
||||||
:param errors: whether or not to raise exceptions if unknown sort column
|
|
||||||
is passed
|
|
||||||
"""
|
|
||||||
entities = [entity.entity_zero.class_ for entity in query._entities]
|
|
||||||
for mapper in query._join_entities:
|
|
||||||
if isinstance(mapper, Mapper):
|
|
||||||
entities.append(mapper.class_)
|
|
||||||
else:
|
|
||||||
entities.append(mapper)
|
|
||||||
|
|
||||||
# get all label names for queries such as:
|
|
||||||
# db.session.query(Category, db.func.count(Article.id).label('articles'))
|
|
||||||
labels = []
|
|
||||||
for entity in query._entities:
|
|
||||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
|
||||||
labels.append(entity._label_name)
|
|
||||||
|
|
||||||
if not sort:
|
|
||||||
return query
|
|
||||||
|
|
||||||
if sort[0] == '-':
|
|
||||||
func = desc
|
|
||||||
sort = sort[1:]
|
|
||||||
else:
|
|
||||||
func = asc
|
|
||||||
|
|
||||||
component = None
|
|
||||||
parts = sort.split('-')
|
|
||||||
if len(parts) > 1:
|
|
||||||
component = parts[0]
|
|
||||||
sort = parts[1]
|
|
||||||
if sort in labels:
|
|
||||||
return query.order_by(func(sort))
|
|
||||||
|
|
||||||
for entity in entities:
|
|
||||||
table = entity.__table__
|
|
||||||
if component and table.name != component:
|
|
||||||
continue
|
|
||||||
if sort in table.columns:
|
|
||||||
try:
|
|
||||||
attr = getattr(entity, sort)
|
|
||||||
query = query.order_by(func(attr))
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
break
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
|
||||||
def defer_except(query, columns):
|
|
||||||
"""
|
|
||||||
Deferred loads all columns in given query, except the ones given.
|
|
||||||
|
|
||||||
This function is very useful when working with models with myriad of
|
|
||||||
columns and you want to deferred load many columns.
|
|
||||||
|
|
||||||
>>> from sqlalchemy_utils import defer_except
|
|
||||||
>>> query = session.query(Article)
|
|
||||||
>>> query = defer_except(Article, [Article.id, Article.name])
|
|
||||||
|
|
||||||
:param columns: columns not to deferred load
|
|
||||||
"""
|
|
||||||
model = query._entities[0].entity_zero.class_
|
|
||||||
fields = set(model._sa_class_manager.values())
|
|
||||||
for field in fields:
|
|
||||||
property_ = field.property
|
|
||||||
if isinstance(property_, ColumnProperty):
|
|
||||||
column = property_.columns[0]
|
|
||||||
if column.name not in columns:
|
|
||||||
query = query.options(defer(property_.key))
|
|
||||||
return query
|
|
||||||
|
|
||||||
|
|
||||||
def escape_like(string, escape_char='*'):
|
|
||||||
"""
|
|
||||||
Escapes the string paremeter used in SQL LIKE expressions
|
|
||||||
|
|
||||||
>>> from sqlalchemy_utils import escape_like
|
|
||||||
>>> query = session.query(User).filter(
|
|
||||||
... User.name.ilike(escape_like('John'))
|
|
||||||
... )
|
|
||||||
|
|
||||||
|
|
||||||
:param string: a string to escape
|
|
||||||
:param escape_char: escape character
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
string
|
|
||||||
.replace(escape_char, escape_char * 2)
|
|
||||||
.replace('%', escape_char + '%')
|
|
||||||
.replace('_', escape_char + '_')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def dependent_foreign_keys(model_class):
|
|
||||||
"""
|
|
||||||
Returns dependent foreign keys as dicts for given model class.
|
|
||||||
|
|
||||||
** Experimental function **
|
|
||||||
"""
|
|
||||||
session = object_session(model_class)
|
|
||||||
|
|
||||||
engine = session.bind
|
|
||||||
inspector = reflection.Inspector.from_engine(engine)
|
|
||||||
table_names = inspector.get_table_names()
|
|
||||||
|
|
||||||
dependent_foreign_keys = {}
|
|
||||||
|
|
||||||
for table_name in table_names:
|
|
||||||
fks = inspector.get_foreign_keys(table_name)
|
|
||||||
if fks:
|
|
||||||
dependent_foreign_keys[table_name] = []
|
|
||||||
for fk in fks:
|
|
||||||
if fk['referred_table'] == model_class.__tablename__:
|
|
||||||
dependent_foreign_keys[table_name].append(fk)
|
|
||||||
return dependent_foreign_keys
|
|
||||||
|
|
||||||
|
|
||||||
class Merger(object):
|
|
||||||
def memory_merge(self, session, table_name, old_values, new_values):
|
|
||||||
# try to fetch mappers for given table and update in memory objects as
|
|
||||||
# well as database table
|
|
||||||
found = False
|
|
||||||
for mapper in mapperlib._mapper_registry:
|
|
||||||
class_ = mapper.class_
|
|
||||||
if table_name == class_.__table__.name:
|
|
||||||
try:
|
|
||||||
(
|
|
||||||
session.query(mapper.class_)
|
|
||||||
.filter_by(**old_values)
|
|
||||||
.update(
|
|
||||||
new_values,
|
|
||||||
'fetch'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except sa.exc.IntegrityError:
|
|
||||||
pass
|
|
||||||
found = True
|
|
||||||
return found
|
|
||||||
|
|
||||||
def raw_merge(self, session, table, old_values, new_values):
|
|
||||||
conditions = []
|
|
||||||
for key, value in old_values.items():
|
|
||||||
conditions.append(getattr(table.c, key) == value)
|
|
||||||
sql = (
|
|
||||||
table
|
|
||||||
.update()
|
|
||||||
.where(sa.and_(
|
|
||||||
*conditions
|
|
||||||
))
|
|
||||||
.values(
|
|
||||||
new_values
|
|
||||||
)
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
session.execute(sql)
|
|
||||||
except sa.exc.IntegrityError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def merge_update(self, table_name, from_, to, foreign_key):
|
|
||||||
session = object_session(from_)
|
|
||||||
constrained_columns = foreign_key['constrained_columns']
|
|
||||||
referred_columns = foreign_key['referred_columns']
|
|
||||||
metadata = from_.metadata
|
|
||||||
table = metadata.tables[table_name]
|
|
||||||
|
|
||||||
new_values = {}
|
|
||||||
for index, column in enumerate(constrained_columns):
|
|
||||||
new_values[column] = getattr(
|
|
||||||
to, referred_columns[index]
|
|
||||||
)
|
|
||||||
|
|
||||||
old_values = {}
|
|
||||||
for index, column in enumerate(constrained_columns):
|
|
||||||
old_values[column] = getattr(
|
|
||||||
from_, referred_columns[index]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.memory_merge(session, table_name, old_values, new_values):
|
|
||||||
self.raw_merge(session, table, old_values, new_values)
|
|
||||||
|
|
||||||
def __call__(self, from_, to):
|
|
||||||
"""
|
|
||||||
Merges entity into another entity. After merging deletes the from_
|
|
||||||
argument entity.
|
|
||||||
"""
|
|
||||||
if from_.__tablename__ != to.__tablename__:
|
|
||||||
raise Exception()
|
|
||||||
|
|
||||||
session = object_session(from_)
|
|
||||||
foreign_keys = dependent_foreign_keys(from_)
|
|
||||||
|
|
||||||
for table_name in foreign_keys:
|
|
||||||
for foreign_key in foreign_keys[table_name]:
|
|
||||||
self.merge_update(table_name, from_, to, foreign_key)
|
|
||||||
|
|
||||||
session.delete(from_)
|
|
||||||
|
|
||||||
|
|
||||||
def merge(from_, to, merger=Merger):
|
|
||||||
"""
|
|
||||||
Merges entity into another entity. After merging deletes the from_ argument
|
|
||||||
entity.
|
|
||||||
|
|
||||||
After merging the from_ entity is deleted from database.
|
|
||||||
|
|
||||||
:param from_: an entity to merge into another entity
|
|
||||||
:param to: an entity to merge another entity into
|
|
||||||
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
|
|
||||||
class
|
|
||||||
"""
|
|
||||||
return Merger()(from_, to)
|
|
||||||
|
159
sqlalchemy_utils/functions.py
Normal file
159
sqlalchemy_utils/functions.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
|
from sqlalchemy.sql.expression import desc, asc
|
||||||
|
|
||||||
|
|
||||||
|
def sort_query(query, sort):
|
||||||
|
"""
|
||||||
|
Applies an sql ORDER BY for given query. This function can be easily used
|
||||||
|
with user-defined sorting.
|
||||||
|
|
||||||
|
The examples use the following model definition:
|
||||||
|
|
||||||
|
>>> import sqlalchemy as sa
|
||||||
|
>>> from sqlalchemy import create_engine
|
||||||
|
>>> from sqlalchemy.orm import sessionmaker
|
||||||
|
>>> from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
>>> from sqlalchemy_utils import sort_query
|
||||||
|
>>>
|
||||||
|
>>>
|
||||||
|
>>> engine = create_engine(
|
||||||
|
... 'sqlite:///'
|
||||||
|
... )
|
||||||
|
>>> Base = declarative_base()
|
||||||
|
>>> Session = sessionmaker(bind=engine)
|
||||||
|
>>> session = Session()
|
||||||
|
>>>
|
||||||
|
>>> class Category(Base):
|
||||||
|
... __tablename__ = 'category'
|
||||||
|
... id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
... name = sa.Column(sa.Unicode(255))
|
||||||
|
>>>
|
||||||
|
>>> class Article(Base):
|
||||||
|
... __tablename__ = 'article'
|
||||||
|
... id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
... name = sa.Column(sa.Unicode(255))
|
||||||
|
... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||||
|
...
|
||||||
|
... category = sa.orm.relationship(
|
||||||
|
... Category, primaryjoin=category_id == Category.id
|
||||||
|
... )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
1. Applying simple ascending sort
|
||||||
|
|
||||||
|
>>> query = session.query(Article)
|
||||||
|
>>> query = sort_query(query, 'name')
|
||||||
|
|
||||||
|
2. Appying descending sort
|
||||||
|
|
||||||
|
>>> query = sort_query(query, '-name')
|
||||||
|
|
||||||
|
3. Applying sort to custom calculated label
|
||||||
|
|
||||||
|
>>> query = session.query(
|
||||||
|
... Category, db.func.count(Article.id).label('articles')
|
||||||
|
... )
|
||||||
|
>>> query = sort_query(query, 'articles')
|
||||||
|
|
||||||
|
4. Applying sort to joined table column
|
||||||
|
|
||||||
|
>>> query = session.query(Article).join(Article.category)
|
||||||
|
>>> query = sort_query(query, 'category-name')
|
||||||
|
|
||||||
|
|
||||||
|
:param query: query to be modified
|
||||||
|
:param sort: string that defines the label or column to sort the query by
|
||||||
|
:param errors: whether or not to raise exceptions if unknown sort column
|
||||||
|
is passed
|
||||||
|
"""
|
||||||
|
entities = [entity.entity_zero.class_ for entity in query._entities]
|
||||||
|
for mapper in query._join_entities:
|
||||||
|
if isinstance(mapper, Mapper):
|
||||||
|
entities.append(mapper.class_)
|
||||||
|
else:
|
||||||
|
entities.append(mapper)
|
||||||
|
|
||||||
|
# get all label names for queries such as:
|
||||||
|
# db.session.query(Category, db.func.count(Article.id).label('articles'))
|
||||||
|
labels = []
|
||||||
|
for entity in query._entities:
|
||||||
|
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||||
|
labels.append(entity._label_name)
|
||||||
|
|
||||||
|
if not sort:
|
||||||
|
return query
|
||||||
|
|
||||||
|
if sort[0] == '-':
|
||||||
|
func = desc
|
||||||
|
sort = sort[1:]
|
||||||
|
else:
|
||||||
|
func = asc
|
||||||
|
|
||||||
|
component = None
|
||||||
|
parts = sort.split('-')
|
||||||
|
if len(parts) > 1:
|
||||||
|
component = parts[0]
|
||||||
|
sort = parts[1]
|
||||||
|
if sort in labels:
|
||||||
|
return query.order_by(func(sort))
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
table = entity.__table__
|
||||||
|
if component and table.name != component:
|
||||||
|
continue
|
||||||
|
if sort in table.columns:
|
||||||
|
try:
|
||||||
|
attr = getattr(entity, sort)
|
||||||
|
query = query.order_by(func(attr))
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
def defer_except(query, columns):
|
||||||
|
"""
|
||||||
|
Deferred loads all columns in given query, except the ones given.
|
||||||
|
|
||||||
|
This function is very useful when working with models with myriad of
|
||||||
|
columns and you want to deferred load many columns.
|
||||||
|
|
||||||
|
>>> from sqlalchemy_utils import defer_except
|
||||||
|
>>> query = session.query(Article)
|
||||||
|
>>> query = defer_except(Article, [Article.id, Article.name])
|
||||||
|
|
||||||
|
:param columns: columns not to deferred load
|
||||||
|
"""
|
||||||
|
model = query._entities[0].entity_zero.class_
|
||||||
|
fields = set(model._sa_class_manager.values())
|
||||||
|
for field in fields:
|
||||||
|
property_ = field.property
|
||||||
|
if isinstance(property_, ColumnProperty):
|
||||||
|
column = property_.columns[0]
|
||||||
|
if column.name not in columns:
|
||||||
|
query = query.options(defer(property_.key))
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
def escape_like(string, escape_char='*'):
|
||||||
|
"""
|
||||||
|
Escapes the string paremeter used in SQL LIKE expressions
|
||||||
|
|
||||||
|
>>> from sqlalchemy_utils import escape_like
|
||||||
|
>>> query = session.query(User).filter(
|
||||||
|
... User.name.ilike(escape_like('John'))
|
||||||
|
... )
|
||||||
|
|
||||||
|
|
||||||
|
:param string: a string to escape
|
||||||
|
:param escape_char: escape character
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
string
|
||||||
|
.replace(escape_char, escape_char * 2)
|
||||||
|
.replace('%', escape_char + '%')
|
||||||
|
.replace('_', escape_char + '_')
|
||||||
|
)
|
123
sqlalchemy_utils/merge.py
Normal file
123
sqlalchemy_utils/merge.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.engine import reflection
|
||||||
|
from sqlalchemy.orm import object_session, mapperlib
|
||||||
|
|
||||||
|
|
||||||
|
def dependent_foreign_keys(model_class):
|
||||||
|
"""
|
||||||
|
Returns dependent foreign keys as dicts for given model class.
|
||||||
|
|
||||||
|
** Experimental function **
|
||||||
|
"""
|
||||||
|
session = object_session(model_class)
|
||||||
|
|
||||||
|
engine = session.bind
|
||||||
|
inspector = reflection.Inspector.from_engine(engine)
|
||||||
|
table_names = inspector.get_table_names()
|
||||||
|
|
||||||
|
dependent_foreign_keys = {}
|
||||||
|
|
||||||
|
for table_name in table_names:
|
||||||
|
fks = inspector.get_foreign_keys(table_name)
|
||||||
|
if fks:
|
||||||
|
dependent_foreign_keys[table_name] = []
|
||||||
|
for fk in fks:
|
||||||
|
if fk['referred_table'] == model_class.__tablename__:
|
||||||
|
dependent_foreign_keys[table_name].append(fk)
|
||||||
|
return dependent_foreign_keys
|
||||||
|
|
||||||
|
|
||||||
|
class Merger(object):
|
||||||
|
def memory_merge(self, session, table_name, old_values, new_values):
|
||||||
|
# try to fetch mappers for given table and update in memory objects as
|
||||||
|
# well as database table
|
||||||
|
found = False
|
||||||
|
for mapper in mapperlib._mapper_registry:
|
||||||
|
class_ = mapper.class_
|
||||||
|
if table_name == class_.__table__.name:
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
session.query(mapper.class_)
|
||||||
|
.filter_by(**old_values)
|
||||||
|
.update(
|
||||||
|
new_values,
|
||||||
|
'fetch'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except sa.exc.IntegrityError:
|
||||||
|
pass
|
||||||
|
found = True
|
||||||
|
return found
|
||||||
|
|
||||||
|
def raw_merge(self, session, table, old_values, new_values):
|
||||||
|
conditions = []
|
||||||
|
for key, value in old_values.items():
|
||||||
|
conditions.append(getattr(table.c, key) == value)
|
||||||
|
sql = (
|
||||||
|
table
|
||||||
|
.update()
|
||||||
|
.where(sa.and_(
|
||||||
|
*conditions
|
||||||
|
))
|
||||||
|
.values(
|
||||||
|
new_values
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
session.execute(sql)
|
||||||
|
except sa.exc.IntegrityError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def merge_update(self, table_name, from_, to, foreign_key):
|
||||||
|
session = object_session(from_)
|
||||||
|
constrained_columns = foreign_key['constrained_columns']
|
||||||
|
referred_columns = foreign_key['referred_columns']
|
||||||
|
metadata = from_.metadata
|
||||||
|
table = metadata.tables[table_name]
|
||||||
|
|
||||||
|
new_values = {}
|
||||||
|
for index, column in enumerate(constrained_columns):
|
||||||
|
new_values[column] = getattr(
|
||||||
|
to, referred_columns[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
old_values = {}
|
||||||
|
for index, column in enumerate(constrained_columns):
|
||||||
|
old_values[column] = getattr(
|
||||||
|
from_, referred_columns[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.memory_merge(session, table_name, old_values, new_values):
|
||||||
|
self.raw_merge(session, table, old_values, new_values)
|
||||||
|
|
||||||
|
def __call__(self, from_, to):
|
||||||
|
"""
|
||||||
|
Merges entity into another entity. After merging deletes the from_
|
||||||
|
argument entity.
|
||||||
|
"""
|
||||||
|
if from_.__tablename__ != to.__tablename__:
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
|
session = object_session(from_)
|
||||||
|
foreign_keys = dependent_foreign_keys(from_)
|
||||||
|
|
||||||
|
for table_name in foreign_keys:
|
||||||
|
for foreign_key in foreign_keys[table_name]:
|
||||||
|
self.merge_update(table_name, from_, to, foreign_key)
|
||||||
|
|
||||||
|
session.delete(from_)
|
||||||
|
|
||||||
|
|
||||||
|
def merge(from_, to, merger=Merger):
|
||||||
|
"""
|
||||||
|
Merges entity into another entity. After merging deletes the from_ argument
|
||||||
|
entity.
|
||||||
|
|
||||||
|
After merging the from_ entity is deleted from database.
|
||||||
|
|
||||||
|
:param from_: an entity to merge into another entity
|
||||||
|
:param to: an entity to merge another entity into
|
||||||
|
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
|
||||||
|
class
|
||||||
|
"""
|
||||||
|
return Merger()(from_, to)
|
179
sqlalchemy_utils/types.py
Normal file
179
sqlalchemy_utils/types.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import phonenumbers
|
||||||
|
from functools import wraps
|
||||||
|
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||||
|
from sqlalchemy import types
|
||||||
|
|
||||||
|
|
||||||
|
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
|
||||||
|
'''
|
||||||
|
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
||||||
|
different phone number formats to attributes, so they can be easily used
|
||||||
|
in templates. Phone number validation method is also implemented.
|
||||||
|
|
||||||
|
Takes the raw phone number and country code as params and parses them
|
||||||
|
into a PhoneNumber object.
|
||||||
|
|
||||||
|
.. _Python phonenumbers library:
|
||||||
|
https://github.com/daviddrysdale/python-phonenumbers
|
||||||
|
|
||||||
|
:param raw_number:
|
||||||
|
String representation of the phone number.
|
||||||
|
:param country_code:
|
||||||
|
Country code of the phone number.
|
||||||
|
'''
|
||||||
|
def __init__(self, raw_number, country_code=None):
|
||||||
|
self._phone_number = phonenumbers.parse(raw_number, country_code)
|
||||||
|
super(PhoneNumber, self).__init__(
|
||||||
|
country_code=self._phone_number.country_code,
|
||||||
|
national_number=self._phone_number.national_number,
|
||||||
|
extension=self._phone_number.extension,
|
||||||
|
italian_leading_zero=self._phone_number.italian_leading_zero,
|
||||||
|
raw_input=self._phone_number.raw_input,
|
||||||
|
country_code_source=self._phone_number.country_code_source,
|
||||||
|
preferred_domestic_carrier_code=
|
||||||
|
self._phone_number.preferred_domestic_carrier_code
|
||||||
|
)
|
||||||
|
self.national = phonenumbers.format_number(
|
||||||
|
self._phone_number,
|
||||||
|
phonenumbers.PhoneNumberFormat.NATIONAL
|
||||||
|
)
|
||||||
|
self.international = phonenumbers.format_number(
|
||||||
|
self._phone_number,
|
||||||
|
phonenumbers.PhoneNumberFormat.INTERNATIONAL
|
||||||
|
)
|
||||||
|
self.e164 = phonenumbers.format_number(
|
||||||
|
self._phone_number,
|
||||||
|
phonenumbers.PhoneNumberFormat.E164
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_valid_number(self):
|
||||||
|
return phonenumbers.is_valid_number(self._phone_number)
|
||||||
|
|
||||||
|
|
||||||
|
class PhoneNumberType(types.TypeDecorator):
|
||||||
|
"""
|
||||||
|
Changes PhoneNumber objects to a string representation on the way in and
|
||||||
|
changes them back to PhoneNumber objects on the way out. If E164 is used
|
||||||
|
as storing format, no country code is needed for parsing the database
|
||||||
|
value to PhoneNumber object.
|
||||||
|
"""
|
||||||
|
STORE_FORMAT = 'e164'
|
||||||
|
impl = types.Unicode(20)
|
||||||
|
|
||||||
|
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
|
||||||
|
super(PhoneNumberType, self).__init__(*args, **kwargs)
|
||||||
|
self.country_code = country_code
|
||||||
|
self.impl = types.Unicode(max_length)
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
return getattr(value, self.STORE_FORMAT)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
return PhoneNumber(value, self.country_code)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberRangeRawType(types.UserDefinedType):
|
||||||
|
"""
|
||||||
|
Raw number range type, only supports PostgreSQL for now.
|
||||||
|
"""
|
||||||
|
def get_col_spec(self):
|
||||||
|
return 'int4range'
|
||||||
|
|
||||||
|
|
||||||
|
class NumberRangeType(types.TypeDecorator):
|
||||||
|
impl = NumberRangeRawType
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
return value
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
return NumberRange.from_normalized_str(value)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberRange(object):
|
||||||
|
def __init__(self, min_value, max_value):
|
||||||
|
self.min_value = min_value
|
||||||
|
self.max_value = max_value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_normalized_str(cls, value):
|
||||||
|
if value is not None:
|
||||||
|
values = value[1:-1].split(',')
|
||||||
|
min_value, max_value = map(
|
||||||
|
lambda a: int(a.strip()), values
|
||||||
|
)
|
||||||
|
|
||||||
|
if value[0] == '(':
|
||||||
|
min_value += 1
|
||||||
|
|
||||||
|
if value[1] == ')':
|
||||||
|
max_value -= 1
|
||||||
|
|
||||||
|
return cls(min_value, max_value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(cls, value):
|
||||||
|
if value is not None:
|
||||||
|
values = value.split('-')
|
||||||
|
min_value, max_value = map(
|
||||||
|
lambda a: int(a.strip()), values
|
||||||
|
)
|
||||||
|
return cls(min_value, max_value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'NumberRange(%r, %r)' % (self.min_value, self.max_value)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return '[%s, %s]' % (self.min_value, self.max_value)
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
try:
|
||||||
|
return NumberRange(
|
||||||
|
self.min_value + other.min_value,
|
||||||
|
self.max_value + other.max_value
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __iadd__(self, other):
|
||||||
|
try:
|
||||||
|
self.min_value += other.min_value
|
||||||
|
self.max_value += other.max_value
|
||||||
|
return self
|
||||||
|
except AttributeError:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __sub__(self, other):
|
||||||
|
try:
|
||||||
|
return NumberRange(
|
||||||
|
self.min_value - other.min_value,
|
||||||
|
self.max_value - other.max_value
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __isub__(self, other):
|
||||||
|
try:
|
||||||
|
self.min_value -= other.min_value
|
||||||
|
self.max_value -= other.max_value
|
||||||
|
return self
|
||||||
|
except AttributeError:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class InstrumentedList(_InstrumentedList):
|
||||||
|
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
||||||
|
additional functionality."""
|
||||||
|
|
||||||
|
def any(self, attr):
|
||||||
|
return any(getattr(item, attr) for item in self)
|
||||||
|
|
||||||
|
def all(self, attr):
|
||||||
|
return all(getattr(item, attr) for item in self)
|
||||||
|
|
||||||
|
|
||||||
|
def instrumented_list(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return InstrumentedList([item for item in f(*args, **kwargs)])
|
||||||
|
return wrapper
|
397
tests.py
397
tests.py
@@ -1,397 +0,0 @@
|
|||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
from sqlalchemy_utils import (
|
|
||||||
escape_like,
|
|
||||||
sort_query,
|
|
||||||
InstrumentedList,
|
|
||||||
PhoneNumber,
|
|
||||||
PhoneNumberType,
|
|
||||||
merge
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCase(object):
|
|
||||||
|
|
||||||
def setup_method(self, method):
|
|
||||||
self.engine = create_engine('sqlite:///:memory:')
|
|
||||||
self.Base = declarative_base()
|
|
||||||
|
|
||||||
self.create_models()
|
|
||||||
self.Base.metadata.create_all(self.engine)
|
|
||||||
|
|
||||||
Session = sessionmaker(bind=self.engine)
|
|
||||||
self.session = Session()
|
|
||||||
|
|
||||||
def teardown_method(self, method):
|
|
||||||
self.session.close_all()
|
|
||||||
self.Base.metadata.drop_all(self.engine)
|
|
||||||
self.engine.dispose()
|
|
||||||
|
|
||||||
def create_models(self):
|
|
||||||
class User(self.Base):
|
|
||||||
__tablename__ = 'user'
|
|
||||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
phone_number = sa.Column(PhoneNumberType())
|
|
||||||
|
|
||||||
class Category(self.Base):
|
|
||||||
__tablename__ = 'category'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
class Article(self.Base):
|
|
||||||
__tablename__ = 'article'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
|
||||||
|
|
||||||
category = sa.orm.relationship(
|
|
||||||
Category,
|
|
||||||
primaryjoin=category_id == Category.id,
|
|
||||||
backref=sa.orm.backref(
|
|
||||||
'articles',
|
|
||||||
collection_class=InstrumentedList
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.User = User
|
|
||||||
self.Category = Category
|
|
||||||
self.Article = Article
|
|
||||||
|
|
||||||
|
|
||||||
class TestInstrumentedList(TestCase):
|
|
||||||
def test_any_returns_true_if_member_has_attr_defined(self):
|
|
||||||
category = self.Category()
|
|
||||||
category.articles.append(self.Article())
|
|
||||||
category.articles.append(self.Article(name=u'some name'))
|
|
||||||
assert category.articles.any('name')
|
|
||||||
|
|
||||||
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
|
||||||
category = self.Category()
|
|
||||||
category.articles.append(self.Article())
|
|
||||||
assert not category.articles.any('name')
|
|
||||||
|
|
||||||
|
|
||||||
class TestEscapeLike(TestCase):
|
|
||||||
def test_escapes_wildcards(self):
|
|
||||||
assert escape_like('_*%') == '*_***%'
|
|
||||||
|
|
||||||
|
|
||||||
class TestSortQuery(TestCase):
|
|
||||||
def test_without_sort_param_returns_the_query_object_untouched(self):
|
|
||||||
query = self.session.query(self.Article)
|
|
||||||
sorted_query = sort_query(query, '')
|
|
||||||
assert query == sorted_query
|
|
||||||
|
|
||||||
def test_sort_by_column_ascending(self):
|
|
||||||
query = sort_query(self.session.query(self.Article), 'name')
|
|
||||||
assert 'ORDER BY article.name ASC' in str(query)
|
|
||||||
|
|
||||||
def test_sort_by_column_descending(self):
|
|
||||||
query = sort_query(self.session.query(self.Article), '-name')
|
|
||||||
assert 'ORDER BY article.name DESC' in str(query)
|
|
||||||
|
|
||||||
def test_skips_unknown_columns(self):
|
|
||||||
query = self.session.query(self.Article)
|
|
||||||
sorted_query = sort_query(query, '-unknown')
|
|
||||||
assert query == sorted_query
|
|
||||||
|
|
||||||
def test_sort_by_calculated_value_ascending(self):
|
|
||||||
query = self.session.query(
|
|
||||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
|
||||||
)
|
|
||||||
query = sort_query(query, 'articles')
|
|
||||||
assert 'ORDER BY articles ASC' in str(query)
|
|
||||||
|
|
||||||
def test_sort_by_calculated_value_descending(self):
|
|
||||||
query = self.session.query(
|
|
||||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
|
||||||
)
|
|
||||||
query = sort_query(query, '-articles')
|
|
||||||
assert 'ORDER BY articles DESC' in str(query)
|
|
||||||
|
|
||||||
def test_sort_by_joined_table_column(self):
|
|
||||||
query = self.session.query(self.Article).join(self.Article.category)
|
|
||||||
sorted_query = sort_query(query, 'category-name')
|
|
||||||
assert 'category.name ASC' in str(sorted_query)
|
|
||||||
|
|
||||||
|
|
||||||
class TestPhoneNumber(object):
|
|
||||||
def setup_method(self, method):
|
|
||||||
self.valid_phone_numbers = [
|
|
||||||
'040 1234567',
|
|
||||||
'+358 401234567',
|
|
||||||
'09 2501234',
|
|
||||||
'+358 92501234',
|
|
||||||
'0800 939393',
|
|
||||||
'09 4243 0456',
|
|
||||||
'0600 900 500'
|
|
||||||
]
|
|
||||||
self.invalid_phone_numbers = [
|
|
||||||
'abc',
|
|
||||||
'+040 1234567',
|
|
||||||
'0111234567',
|
|
||||||
'358'
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_valid_phone_numbers(self):
|
|
||||||
for raw_number in self.valid_phone_numbers:
|
|
||||||
phone_number = PhoneNumber(raw_number, 'FI')
|
|
||||||
assert phone_number.is_valid_number()
|
|
||||||
|
|
||||||
def test_invalid_phone_numbers(self):
|
|
||||||
for raw_number in self.invalid_phone_numbers:
|
|
||||||
try:
|
|
||||||
phone_number = PhoneNumber(raw_number, 'FI')
|
|
||||||
assert not phone_number.is_valid_number()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_phone_number_attributes(self):
|
|
||||||
phone_number = PhoneNumber('+358401234567')
|
|
||||||
assert phone_number.e164 == u'+358401234567'
|
|
||||||
assert phone_number.international == u'+358 40 1234567'
|
|
||||||
assert phone_number.national == u'040 1234567'
|
|
||||||
|
|
||||||
|
|
||||||
class TestPhoneNumberType(TestCase):
|
|
||||||
def setup_method(self, method):
|
|
||||||
super(TestPhoneNumberType, self).setup_method(method)
|
|
||||||
self.phone_number = PhoneNumber(
|
|
||||||
'040 1234567',
|
|
||||||
'FI'
|
|
||||||
)
|
|
||||||
self.user = self.User()
|
|
||||||
self.user.name = u'Someone'
|
|
||||||
self.user.phone_number = self.phone_number
|
|
||||||
self.session.add(self.user)
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
def test_query_returns_phone_number_object(self):
|
|
||||||
queried_user = self.session.query(self.User).first()
|
|
||||||
assert queried_user.phone_number == self.phone_number
|
|
||||||
|
|
||||||
def test_phone_number_is_stored_as_string(self):
|
|
||||||
result = self.session.execute(
|
|
||||||
'SELECT phone_number FROM user WHERE id=:param',
|
|
||||||
{'param': self.user.id}
|
|
||||||
)
|
|
||||||
assert result.first()[0] == u'+358401234567'
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseTestCase(object):
|
|
||||||
def create_models(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup_method(self, method):
|
|
||||||
self.engine = create_engine(
|
|
||||||
'sqlite:///'
|
|
||||||
)
|
|
||||||
#self.engine.echo = True
|
|
||||||
self.Base = declarative_base()
|
|
||||||
|
|
||||||
self.create_models()
|
|
||||||
self.Base.metadata.create_all(self.engine)
|
|
||||||
Session = sessionmaker(bind=self.engine)
|
|
||||||
self.session = Session()
|
|
||||||
|
|
||||||
def teardown_method(self, method):
|
|
||||||
self.engine.dispose()
|
|
||||||
self.Base.metadata.drop_all(self.engine)
|
|
||||||
self.session.expunge_all()
|
|
||||||
|
|
||||||
|
|
||||||
class TestMerge(DatabaseTestCase):
|
|
||||||
def create_models(self):
|
|
||||||
class User(self.Base):
|
|
||||||
__tablename__ = 'user'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'User(%r)' % self.name
|
|
||||||
|
|
||||||
class BlogPost(self.Base):
|
|
||||||
__tablename__ = 'blog_post'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
title = sa.Column(sa.Unicode(255))
|
|
||||||
content = sa.Column(sa.UnicodeText)
|
|
||||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
|
||||||
|
|
||||||
author = sa.orm.relationship(User)
|
|
||||||
|
|
||||||
self.User = User
|
|
||||||
self.BlogPost = BlogPost
|
|
||||||
|
|
||||||
def test_updates_foreign_keys(self):
|
|
||||||
john = self.User(name=u'John')
|
|
||||||
jack = self.User(name=u'Jack')
|
|
||||||
post = self.BlogPost(title=u'Some title', author=john)
|
|
||||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
|
||||||
self.session.add(john)
|
|
||||||
self.session.add(jack)
|
|
||||||
self.session.add(post)
|
|
||||||
self.session.add(post2)
|
|
||||||
self.session.commit()
|
|
||||||
merge(john, jack)
|
|
||||||
assert post.author == jack
|
|
||||||
assert post2.author == jack
|
|
||||||
|
|
||||||
def test_deletes_from_entity(self):
|
|
||||||
john = self.User(name=u'John')
|
|
||||||
jack = self.User(name=u'Jack')
|
|
||||||
self.session.add(john)
|
|
||||||
self.session.add(jack)
|
|
||||||
self.session.commit()
|
|
||||||
merge(john, jack)
|
|
||||||
assert john in self.session.deleted
|
|
||||||
|
|
||||||
|
|
||||||
class TestMergeManyToManyAssociations(DatabaseTestCase):
|
|
||||||
def create_models(self):
|
|
||||||
class User(self.Base):
|
|
||||||
__tablename__ = 'user'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'User(%r)' % self.name
|
|
||||||
|
|
||||||
team_member = sa.Table(
|
|
||||||
'team_member', self.Base.metadata,
|
|
||||||
sa.Column(
|
|
||||||
'user_id', sa.Integer,
|
|
||||||
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
|
||||||
primary_key=True
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
'team_id', sa.Integer,
|
|
||||||
sa.ForeignKey('team.id', ondelete='CASCADE'),
|
|
||||||
primary_key=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
class Team(self.Base):
|
|
||||||
__tablename__ = 'team'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
members = sa.orm.relationship(
|
|
||||||
User,
|
|
||||||
secondary=team_member,
|
|
||||||
backref='teams'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.User = User
|
|
||||||
self.Team = Team
|
|
||||||
|
|
||||||
def test_when_association_only_exists_in_from_entity(self):
|
|
||||||
john = self.User(name=u'John')
|
|
||||||
jack = self.User(name=u'Jack')
|
|
||||||
team = self.Team(name=u'Team')
|
|
||||||
team.members.append(john)
|
|
||||||
self.session.add(john)
|
|
||||||
self.session.add(jack)
|
|
||||||
self.session.commit()
|
|
||||||
merge(john, jack)
|
|
||||||
assert john not in team.members
|
|
||||||
assert jack in team.members
|
|
||||||
|
|
||||||
def test_when_association_exists_in_both(self):
|
|
||||||
john = self.User(name=u'John')
|
|
||||||
jack = self.User(name=u'Jack')
|
|
||||||
team = self.Team(name=u'Team')
|
|
||||||
team.members.append(john)
|
|
||||||
team.members.append(jack)
|
|
||||||
self.session.add(john)
|
|
||||||
self.session.add(jack)
|
|
||||||
self.session.commit()
|
|
||||||
merge(john, jack)
|
|
||||||
assert john not in team.members
|
|
||||||
assert jack in team.members
|
|
||||||
count = self.session.execute(
|
|
||||||
'SELECT COUNT(1) FROM team_member'
|
|
||||||
).fetchone()[0]
|
|
||||||
assert count == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
|
|
||||||
def create_models(self):
|
|
||||||
class Team(self.Base):
|
|
||||||
__tablename__ = 'team'
|
|
||||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
class User(self.Base):
|
|
||||||
__tablename__ = 'user'
|
|
||||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
class TeamMember(self.Base):
|
|
||||||
__tablename__ = 'team_member'
|
|
||||||
user_id = sa.Column(
|
|
||||||
sa.Integer,
|
|
||||||
sa.ForeignKey(User.id, ondelete='CASCADE'),
|
|
||||||
primary_key=True
|
|
||||||
)
|
|
||||||
team_id = sa.Column(
|
|
||||||
sa.Integer,
|
|
||||||
sa.ForeignKey(Team.id, ondelete='CASCADE'),
|
|
||||||
primary_key=True
|
|
||||||
)
|
|
||||||
role = sa.Column(sa.Unicode(255))
|
|
||||||
team = sa.orm.relationship(
|
|
||||||
Team,
|
|
||||||
backref=sa.orm.backref(
|
|
||||||
'members',
|
|
||||||
cascade='all, delete-orphan'
|
|
||||||
),
|
|
||||||
primaryjoin=team_id == Team.id,
|
|
||||||
)
|
|
||||||
user = sa.orm.relationship(
|
|
||||||
User,
|
|
||||||
backref=sa.orm.backref(
|
|
||||||
'memberships',
|
|
||||||
cascade='all, delete-orphan'
|
|
||||||
),
|
|
||||||
primaryjoin=user_id == User.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.User = User
|
|
||||||
self.TeamMember = TeamMember
|
|
||||||
self.Team = Team
|
|
||||||
|
|
||||||
def test_when_association_only_exists_in_from_entity(self):
|
|
||||||
john = self.User(name=u'John')
|
|
||||||
jack = self.User(name=u'Jack')
|
|
||||||
team = self.Team(name=u'Team')
|
|
||||||
team.members.append(self.TeamMember(user=john))
|
|
||||||
self.session.add(john)
|
|
||||||
self.session.add(jack)
|
|
||||||
self.session.add(team)
|
|
||||||
self.session.commit()
|
|
||||||
merge(john, jack)
|
|
||||||
self.session.commit()
|
|
||||||
users = [member.user for member in team.members]
|
|
||||||
assert john not in users
|
|
||||||
assert jack in users
|
|
||||||
|
|
||||||
def test_when_association_exists_in_both(self):
|
|
||||||
john = self.User(name=u'John')
|
|
||||||
jack = self.User(name=u'Jack')
|
|
||||||
team = self.Team(name=u'Team')
|
|
||||||
team.members.append(self.TeamMember(user=john))
|
|
||||||
team.members.append(self.TeamMember(user=jack))
|
|
||||||
self.session.add(john)
|
|
||||||
self.session.add(jack)
|
|
||||||
self.session.add(team)
|
|
||||||
self.session.commit()
|
|
||||||
merge(john, jack)
|
|
||||||
users = [member.user for member in team.members]
|
|
||||||
assert john not in users
|
|
||||||
assert jack in users
|
|
||||||
assert self.session.query(self.TeamMember).count() == 1
|
|
86
tests/__init__.py
Normal file
86
tests/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from sqlalchemy_utils import (
|
||||||
|
escape_like,
|
||||||
|
sort_query,
|
||||||
|
InstrumentedList,
|
||||||
|
PhoneNumber,
|
||||||
|
PhoneNumberType,
|
||||||
|
merge
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase(object):
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.engine = create_engine(
|
||||||
|
'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||||
|
)
|
||||||
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
self.create_models()
|
||||||
|
self.Base.metadata.create_all(self.engine)
|
||||||
|
|
||||||
|
Session = sessionmaker(bind=self.engine)
|
||||||
|
self.session = Session()
|
||||||
|
|
||||||
|
def teardown_method(self, method):
|
||||||
|
self.session.close_all()
|
||||||
|
self.Base.metadata.drop_all(self.engine)
|
||||||
|
self.engine.dispose()
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
phone_number = sa.Column(PhoneNumberType())
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
|
__tablename__ = 'category'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Article(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||||
|
|
||||||
|
category = sa.orm.relationship(
|
||||||
|
Category,
|
||||||
|
primaryjoin=category_id == Category.id,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'articles',
|
||||||
|
collection_class=InstrumentedList
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Category = Category
|
||||||
|
self.Article = Article
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseTestCase(object):
|
||||||
|
def create_models(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.engine = create_engine(
|
||||||
|
'sqlite:///'
|
||||||
|
)
|
||||||
|
#self.engine.echo = True
|
||||||
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
self.create_models()
|
||||||
|
self.Base.metadata.create_all(self.engine)
|
||||||
|
Session = sessionmaker(bind=self.engine)
|
||||||
|
self.session = Session()
|
||||||
|
|
||||||
|
def teardown_method(self, method):
|
||||||
|
self.engine.dispose()
|
||||||
|
self.Base.metadata.drop_all(self.engine)
|
||||||
|
self.session.expunge_all()
|
14
tests/test_instrumented_list.py
Normal file
14
tests/test_instrumented_list.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestInstrumentedList(TestCase):
|
||||||
|
def test_any_returns_true_if_member_has_attr_defined(self):
|
||||||
|
category = self.Category()
|
||||||
|
category.articles.append(self.Article())
|
||||||
|
category.articles.append(self.Article(name=u'some name'))
|
||||||
|
assert category.articles.any('name')
|
||||||
|
|
||||||
|
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
||||||
|
category = self.Category()
|
||||||
|
category.articles.append(self.Article())
|
||||||
|
assert not category.articles.any('name')
|
196
tests/test_merge.py
Normal file
196
tests/test_merge.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import merge
|
||||||
|
|
||||||
|
from tests import DatabaseTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestMerge(DatabaseTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'User(%r)' % self.name
|
||||||
|
|
||||||
|
class BlogPost(self.Base):
|
||||||
|
__tablename__ = 'blog_post'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
title = sa.Column(sa.Unicode(255))
|
||||||
|
content = sa.Column(sa.UnicodeText)
|
||||||
|
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||||
|
|
||||||
|
author = sa.orm.relationship(User)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.BlogPost = BlogPost
|
||||||
|
|
||||||
|
def test_updates_foreign_keys(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
post = self.BlogPost(title=u'Some title', author=john)
|
||||||
|
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.add(post)
|
||||||
|
self.session.add(post2)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert post.author == jack
|
||||||
|
assert post2.author == jack
|
||||||
|
|
||||||
|
def test_deletes_from_entity(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert john in self.session.deleted
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeManyToManyAssociations(DatabaseTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'User(%r)' % self.name
|
||||||
|
|
||||||
|
team_member = sa.Table(
|
||||||
|
'team_member', self.Base.metadata,
|
||||||
|
sa.Column(
|
||||||
|
'user_id', sa.Integer,
|
||||||
|
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'team_id', sa.Integer,
|
||||||
|
sa.ForeignKey('team.id', ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Team(self.Base):
|
||||||
|
__tablename__ = 'team'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
members = sa.orm.relationship(
|
||||||
|
User,
|
||||||
|
secondary=team_member,
|
||||||
|
backref='teams'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Team = Team
|
||||||
|
|
||||||
|
def test_when_association_only_exists_in_from_entity(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(john)
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert john not in team.members
|
||||||
|
assert jack in team.members
|
||||||
|
|
||||||
|
def test_when_association_exists_in_both(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(john)
|
||||||
|
team.members.append(jack)
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert john not in team.members
|
||||||
|
assert jack in team.members
|
||||||
|
count = self.session.execute(
|
||||||
|
'SELECT COUNT(1) FROM team_member'
|
||||||
|
).fetchone()[0]
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Team(self.Base):
|
||||||
|
__tablename__ = 'team'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class TeamMember(self.Base):
|
||||||
|
__tablename__ = 'team_member'
|
||||||
|
user_id = sa.Column(
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey(User.id, ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
team_id = sa.Column(
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey(Team.id, ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
role = sa.Column(sa.Unicode(255))
|
||||||
|
team = sa.orm.relationship(
|
||||||
|
Team,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'members',
|
||||||
|
cascade='all, delete-orphan'
|
||||||
|
),
|
||||||
|
primaryjoin=team_id == Team.id,
|
||||||
|
)
|
||||||
|
user = sa.orm.relationship(
|
||||||
|
User,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'memberships',
|
||||||
|
cascade='all, delete-orphan'
|
||||||
|
),
|
||||||
|
primaryjoin=user_id == User.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.TeamMember = TeamMember
|
||||||
|
self.Team = Team
|
||||||
|
|
||||||
|
def test_when_association_only_exists_in_from_entity(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(self.TeamMember(user=john))
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.add(team)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
self.session.commit()
|
||||||
|
users = [member.user for member in team.members]
|
||||||
|
assert john not in users
|
||||||
|
assert jack in users
|
||||||
|
|
||||||
|
def test_when_association_exists_in_both(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(self.TeamMember(user=john))
|
||||||
|
team.members.append(self.TeamMember(user=jack))
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.add(team)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
users = [member.user for member in team.members]
|
||||||
|
assert john not in users
|
||||||
|
assert jack in users
|
||||||
|
assert self.session.query(self.TeamMember).count() == 1
|
65
tests/test_phonenumber_type.py
Normal file
65
tests/test_phonenumber_type.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from tests import TestCase
|
||||||
|
from sqlalchemy_utils import PhoneNumber
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhoneNumber(object):
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.valid_phone_numbers = [
|
||||||
|
'040 1234567',
|
||||||
|
'+358 401234567',
|
||||||
|
'09 2501234',
|
||||||
|
'+358 92501234',
|
||||||
|
'0800 939393',
|
||||||
|
'09 4243 0456',
|
||||||
|
'0600 900 500'
|
||||||
|
]
|
||||||
|
self.invalid_phone_numbers = [
|
||||||
|
'abc',
|
||||||
|
'+040 1234567',
|
||||||
|
'0111234567',
|
||||||
|
'358'
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_valid_phone_numbers(self):
|
||||||
|
for raw_number in self.valid_phone_numbers:
|
||||||
|
phone_number = PhoneNumber(raw_number, 'FI')
|
||||||
|
assert phone_number.is_valid_number()
|
||||||
|
|
||||||
|
def test_invalid_phone_numbers(self):
|
||||||
|
for raw_number in self.invalid_phone_numbers:
|
||||||
|
try:
|
||||||
|
phone_number = PhoneNumber(raw_number, 'FI')
|
||||||
|
assert not phone_number.is_valid_number()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_phone_number_attributes(self):
|
||||||
|
phone_number = PhoneNumber('+358401234567')
|
||||||
|
assert phone_number.e164 == u'+358401234567'
|
||||||
|
assert phone_number.international == u'+358 40 1234567'
|
||||||
|
assert phone_number.national == u'040 1234567'
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhoneNumberType(TestCase):
|
||||||
|
def setup_method(self, method):
|
||||||
|
super(TestPhoneNumberType, self).setup_method(method)
|
||||||
|
self.phone_number = PhoneNumber(
|
||||||
|
'040 1234567',
|
||||||
|
'FI'
|
||||||
|
)
|
||||||
|
self.user = self.User()
|
||||||
|
self.user.name = u'Someone'
|
||||||
|
self.user.phone_number = self.phone_number
|
||||||
|
self.session.add(self.user)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def test_query_returns_phone_number_object(self):
|
||||||
|
queried_user = self.session.query(self.User).first()
|
||||||
|
assert queried_user.phone_number == self.phone_number
|
||||||
|
|
||||||
|
def test_phone_number_is_stored_as_string(self):
|
||||||
|
result = self.session.execute(
|
||||||
|
'SELECT phone_number FROM "user" WHERE id=:param',
|
||||||
|
{'param': self.user.id}
|
||||||
|
)
|
||||||
|
assert result.first()[0] == u'+358401234567'
|
47
tests/test_utility_functions.py
Normal file
47
tests/test_utility_functions.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import escape_like, sort_query
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestEscapeLike(TestCase):
|
||||||
|
def test_escapes_wildcards(self):
|
||||||
|
assert escape_like('_*%') == '*_***%'
|
||||||
|
|
||||||
|
|
||||||
|
class TestSortQuery(TestCase):
|
||||||
|
def test_without_sort_param_returns_the_query_object_untouched(self):
|
||||||
|
query = self.session.query(self.Article)
|
||||||
|
sorted_query = sort_query(query, '')
|
||||||
|
assert query == sorted_query
|
||||||
|
|
||||||
|
def test_sort_by_column_ascending(self):
|
||||||
|
query = sort_query(self.session.query(self.Article), 'name')
|
||||||
|
assert 'ORDER BY article.name ASC' in str(query)
|
||||||
|
|
||||||
|
def test_sort_by_column_descending(self):
|
||||||
|
query = sort_query(self.session.query(self.Article), '-name')
|
||||||
|
assert 'ORDER BY article.name DESC' in str(query)
|
||||||
|
|
||||||
|
def test_skips_unknown_columns(self):
|
||||||
|
query = self.session.query(self.Article)
|
||||||
|
sorted_query = sort_query(query, '-unknown')
|
||||||
|
assert query == sorted_query
|
||||||
|
|
||||||
|
def test_sort_by_calculated_value_ascending(self):
|
||||||
|
query = self.session.query(
|
||||||
|
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||||
|
)
|
||||||
|
query = sort_query(query, 'articles')
|
||||||
|
assert 'ORDER BY articles ASC' in str(query)
|
||||||
|
|
||||||
|
def test_sort_by_calculated_value_descending(self):
|
||||||
|
query = self.session.query(
|
||||||
|
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||||
|
)
|
||||||
|
query = sort_query(query, '-articles')
|
||||||
|
assert 'ORDER BY articles DESC' in str(query)
|
||||||
|
|
||||||
|
def test_sort_by_joined_table_column(self):
|
||||||
|
query = self.session.query(self.Article).join(self.Article.category)
|
||||||
|
sorted_query = sort_query(query, 'category-name')
|
||||||
|
assert 'category.name ASC' in str(sorted_query)
|
Reference in New Issue
Block a user