Added NumberRange types, refactored file structure

This commit is contained in:
Konsta Vesterinen
2013-03-26 13:03:43 +02:00
parent 965b64c5bf
commit 26db1397d5
10 changed files with 896 additions and 769 deletions

View File

@@ -1,372 +1,27 @@
import phonenumbers
from functools import wraps
import sqlalchemy as sa
from sqlalchemy.engine import reflection
from sqlalchemy.orm import defer, object_session, mapperlib
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql.expression import desc, asc
from sqlalchemy import types
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
'''
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
different phone number formats to attributes, so they can be easily used
in templates. Phone number validation method is also implemented.
Takes the raw phone number and country code as params and parses them
into a PhoneNumber object.
.. _Python phonenumbers library:
https://github.com/daviddrysdale/python-phonenumbers
:param raw_number:
String representation of the phone number.
:param country_code:
Country code of the phone number.
'''
def __init__(self, raw_number, country_code=None):
self._phone_number = phonenumbers.parse(raw_number, country_code)
super(PhoneNumber, self).__init__(
country_code=self._phone_number.country_code,
national_number=self._phone_number.national_number,
extension=self._phone_number.extension,
italian_leading_zero=self._phone_number.italian_leading_zero,
raw_input=self._phone_number.raw_input,
country_code_source=self._phone_number.country_code_source,
preferred_domestic_carrier_code=
self._phone_number.preferred_domestic_carrier_code
)
self.national = phonenumbers.format_number(
self._phone_number,
phonenumbers.PhoneNumberFormat.NATIONAL
)
self.international = phonenumbers.format_number(
self._phone_number,
phonenumbers.PhoneNumberFormat.INTERNATIONAL
)
self.e164 = phonenumbers.format_number(
self._phone_number,
phonenumbers.PhoneNumberFormat.E164
)
def is_valid_number(self):
return phonenumbers.is_valid_number(self._phone_number)
class PhoneNumberType(types.TypeDecorator):
"""
Changes PhoneNumber objects to a string representation on the way in and
changes them back to PhoneNumber objects on the way out. If E164 is used
as storing format, no country code is needed for parsing the database
value to PhoneNumber object.
"""
STORE_FORMAT = 'e164'
impl = types.Unicode(20)
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
super(PhoneNumberType, self).__init__(*args, **kwargs)
self.country_code = country_code
self.impl = types.Unicode(max_length)
def process_bind_param(self, value, dialect):
return getattr(value, self.STORE_FORMAT)
def process_result_value(self, value, dialect):
return PhoneNumber(value, self.country_code)
class InstrumentedList(_InstrumentedList):
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
additional functionality."""
def any(self, attr):
return any(getattr(item, attr) for item in self)
def all(self, attr):
return all(getattr(item, attr) for item in self)
def instrumented_list(f):
@wraps(f)
def wrapper(*args, **kwargs):
return InstrumentedList([item for item in f(*args, **kwargs)])
return wrapper
def sort_query(query, sort):
"""
Applies an sql ORDER BY for given query. This function can be easily used
with user-defined sorting.
The examples use the following model definition:
>>> import sqlalchemy as sa
>>> from sqlalchemy import create_engine
>>> from sqlalchemy.orm import sessionmaker
>>> from sqlalchemy.ext.declarative import declarative_base
>>> from sqlalchemy_utils import sort_query
>>>
>>>
>>> engine = create_engine(
... 'sqlite:///'
... )
>>> Base = declarative_base()
>>> Session = sessionmaker(bind=engine)
>>> session = Session()
>>>
>>> class Category(Base):
... __tablename__ = 'category'
... id = sa.Column(sa.Integer, primary_key=True)
... name = sa.Column(sa.Unicode(255))
>>>
>>> class Article(Base):
... __tablename__ = 'article'
... id = sa.Column(sa.Integer, primary_key=True)
... name = sa.Column(sa.Unicode(255))
... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
...
... category = sa.orm.relationship(
... Category, primaryjoin=category_id == Category.id
... )
1. Applying simple ascending sort
>>> query = session.query(Article)
>>> query = sort_query(query, 'name')
2. Appying descending sort
>>> query = sort_query(query, '-name')
3. Applying sort to custom calculated label
>>> query = session.query(
... Category, db.func.count(Article.id).label('articles')
... )
>>> query = sort_query(query, 'articles')
4. Applying sort to joined table column
>>> query = session.query(Article).join(Article.category)
>>> query = sort_query(query, 'category-name')
:param query: query to be modified
:param sort: string that defines the label or column to sort the query by
:param errors: whether or not to raise exceptions if unknown sort column
is passed
"""
entities = [entity.entity_zero.class_ for entity in query._entities]
for mapper in query._join_entities:
if isinstance(mapper, Mapper):
entities.append(mapper.class_)
else:
entities.append(mapper)
# get all label names for queries such as:
# db.session.query(Category, db.func.count(Article.id).label('articles'))
labels = []
for entity in query._entities:
if isinstance(entity, _ColumnEntity) and entity._label_name:
labels.append(entity._label_name)
if not sort:
return query
if sort[0] == '-':
func = desc
sort = sort[1:]
else:
func = asc
component = None
parts = sort.split('-')
if len(parts) > 1:
component = parts[0]
sort = parts[1]
if sort in labels:
return query.order_by(func(sort))
for entity in entities:
table = entity.__table__
if component and table.name != component:
continue
if sort in table.columns:
try:
attr = getattr(entity, sort)
query = query.order_by(func(attr))
except AttributeError:
pass
break
return query
def defer_except(query, columns):
"""
Deferred loads all columns in given query, except the ones given.
This function is very useful when working with models with myriad of
columns and you want to deferred load many columns.
>>> from sqlalchemy_utils import defer_except
>>> query = session.query(Article)
>>> query = defer_except(Article, [Article.id, Article.name])
:param columns: columns not to deferred load
"""
model = query._entities[0].entity_zero.class_
fields = set(model._sa_class_manager.values())
for field in fields:
property_ = field.property
if isinstance(property_, ColumnProperty):
column = property_.columns[0]
if column.name not in columns:
query = query.options(defer(property_.key))
return query
def escape_like(string, escape_char='*'):
"""
Escapes the string paremeter used in SQL LIKE expressions
>>> from sqlalchemy_utils import escape_like
>>> query = session.query(User).filter(
... User.name.ilike(escape_like('John'))
... )
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)
def dependent_foreign_keys(model_class):
"""
Returns dependent foreign keys as dicts for given model class.
** Experimental function **
"""
session = object_session(model_class)
engine = session.bind
inspector = reflection.Inspector.from_engine(engine)
table_names = inspector.get_table_names()
dependent_foreign_keys = {}
for table_name in table_names:
fks = inspector.get_foreign_keys(table_name)
if fks:
dependent_foreign_keys[table_name] = []
for fk in fks:
if fk['referred_table'] == model_class.__tablename__:
dependent_foreign_keys[table_name].append(fk)
return dependent_foreign_keys
class Merger(object):
def memory_merge(self, session, table_name, old_values, new_values):
# try to fetch mappers for given table and update in memory objects as
# well as database table
found = False
for mapper in mapperlib._mapper_registry:
class_ = mapper.class_
if table_name == class_.__table__.name:
try:
(
session.query(mapper.class_)
.filter_by(**old_values)
.update(
new_values,
'fetch'
)
)
except sa.exc.IntegrityError:
pass
found = True
return found
def raw_merge(self, session, table, old_values, new_values):
conditions = []
for key, value in old_values.items():
conditions.append(getattr(table.c, key) == value)
sql = (
table
.update()
.where(sa.and_(
*conditions
))
.values(
new_values
)
)
try:
session.execute(sql)
except sa.exc.IntegrityError:
pass
def merge_update(self, table_name, from_, to, foreign_key):
session = object_session(from_)
constrained_columns = foreign_key['constrained_columns']
referred_columns = foreign_key['referred_columns']
metadata = from_.metadata
table = metadata.tables[table_name]
new_values = {}
for index, column in enumerate(constrained_columns):
new_values[column] = getattr(
to, referred_columns[index]
)
old_values = {}
for index, column in enumerate(constrained_columns):
old_values[column] = getattr(
from_, referred_columns[index]
)
if not self.memory_merge(session, table_name, old_values, new_values):
self.raw_merge(session, table, old_values, new_values)
def __call__(self, from_, to):
"""
Merges entity into another entity. After merging deletes the from_
argument entity.
"""
if from_.__tablename__ != to.__tablename__:
raise Exception()
session = object_session(from_)
foreign_keys = dependent_foreign_keys(from_)
for table_name in foreign_keys:
for foreign_key in foreign_keys[table_name]:
self.merge_update(table_name, from_, to, foreign_key)
session.delete(from_)
def merge(from_, to, merger=Merger):
"""
Merges entity into another entity. After merging deletes the from_ argument
entity.
After merging the from_ entity is deleted from database.
:param from_: an entity to merge into another entity
:param to: an entity to merge another entity into
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
class
"""
return Merger()(from_, to)
from .functions import sort_query, defer_except, escape_like
from .merge import merge, Merger
from .types import (
instrumented_list,
InstrumentedList,
PhoneNumber,
PhoneNumberType,
NumberRange,
NumberRangeRawType,
NumberRangeType
)
__all__ = (
sort_query,
defer_except,
escape_like,
instrumented_list,
merge,
InstrumentedList,
Merger,
NumberRange,
NumberRangeRawType,
NumberRangeType,
PhoneNumber,
PhoneNumberType,
)

View File

@@ -0,0 +1,159 @@
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql.expression import desc, asc
def sort_query(query, sort):
"""
Applies an sql ORDER BY for given query. This function can be easily used
with user-defined sorting.
The examples use the following model definition:
>>> import sqlalchemy as sa
>>> from sqlalchemy import create_engine
>>> from sqlalchemy.orm import sessionmaker
>>> from sqlalchemy.ext.declarative import declarative_base
>>> from sqlalchemy_utils import sort_query
>>>
>>>
>>> engine = create_engine(
... 'sqlite:///'
... )
>>> Base = declarative_base()
>>> Session = sessionmaker(bind=engine)
>>> session = Session()
>>>
>>> class Category(Base):
... __tablename__ = 'category'
... id = sa.Column(sa.Integer, primary_key=True)
... name = sa.Column(sa.Unicode(255))
>>>
>>> class Article(Base):
... __tablename__ = 'article'
... id = sa.Column(sa.Integer, primary_key=True)
... name = sa.Column(sa.Unicode(255))
... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
...
... category = sa.orm.relationship(
... Category, primaryjoin=category_id == Category.id
... )
1. Applying simple ascending sort
>>> query = session.query(Article)
>>> query = sort_query(query, 'name')
2. Appying descending sort
>>> query = sort_query(query, '-name')
3. Applying sort to custom calculated label
>>> query = session.query(
... Category, db.func.count(Article.id).label('articles')
... )
>>> query = sort_query(query, 'articles')
4. Applying sort to joined table column
>>> query = session.query(Article).join(Article.category)
>>> query = sort_query(query, 'category-name')
:param query: query to be modified
:param sort: string that defines the label or column to sort the query by
:param errors: whether or not to raise exceptions if unknown sort column
is passed
"""
entities = [entity.entity_zero.class_ for entity in query._entities]
for mapper in query._join_entities:
if isinstance(mapper, Mapper):
entities.append(mapper.class_)
else:
entities.append(mapper)
# get all label names for queries such as:
# db.session.query(Category, db.func.count(Article.id).label('articles'))
labels = []
for entity in query._entities:
if isinstance(entity, _ColumnEntity) and entity._label_name:
labels.append(entity._label_name)
if not sort:
return query
if sort[0] == '-':
func = desc
sort = sort[1:]
else:
func = asc
component = None
parts = sort.split('-')
if len(parts) > 1:
component = parts[0]
sort = parts[1]
if sort in labels:
return query.order_by(func(sort))
for entity in entities:
table = entity.__table__
if component and table.name != component:
continue
if sort in table.columns:
try:
attr = getattr(entity, sort)
query = query.order_by(func(attr))
except AttributeError:
pass
break
return query
def defer_except(query, columns):
"""
Deferred loads all columns in given query, except the ones given.
This function is very useful when working with models with myriad of
columns and you want to deferred load many columns.
>>> from sqlalchemy_utils import defer_except
>>> query = session.query(Article)
>>> query = defer_except(Article, [Article.id, Article.name])
:param columns: columns not to deferred load
"""
model = query._entities[0].entity_zero.class_
fields = set(model._sa_class_manager.values())
for field in fields:
property_ = field.property
if isinstance(property_, ColumnProperty):
column = property_.columns[0]
if column.name not in columns:
query = query.options(defer(property_.key))
return query
def escape_like(string, escape_char='*'):
"""
Escapes the string paremeter used in SQL LIKE expressions
>>> from sqlalchemy_utils import escape_like
>>> query = session.query(User).filter(
... User.name.ilike(escape_like('John'))
... )
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)

123
sqlalchemy_utils/merge.py Normal file
View File

@@ -0,0 +1,123 @@
import sqlalchemy as sa
from sqlalchemy.engine import reflection
from sqlalchemy.orm import object_session, mapperlib
def dependent_foreign_keys(model_class):
"""
Returns dependent foreign keys as dicts for given model class.
** Experimental function **
"""
session = object_session(model_class)
engine = session.bind
inspector = reflection.Inspector.from_engine(engine)
table_names = inspector.get_table_names()
dependent_foreign_keys = {}
for table_name in table_names:
fks = inspector.get_foreign_keys(table_name)
if fks:
dependent_foreign_keys[table_name] = []
for fk in fks:
if fk['referred_table'] == model_class.__tablename__:
dependent_foreign_keys[table_name].append(fk)
return dependent_foreign_keys
class Merger(object):
def memory_merge(self, session, table_name, old_values, new_values):
# try to fetch mappers for given table and update in memory objects as
# well as database table
found = False
for mapper in mapperlib._mapper_registry:
class_ = mapper.class_
if table_name == class_.__table__.name:
try:
(
session.query(mapper.class_)
.filter_by(**old_values)
.update(
new_values,
'fetch'
)
)
except sa.exc.IntegrityError:
pass
found = True
return found
def raw_merge(self, session, table, old_values, new_values):
conditions = []
for key, value in old_values.items():
conditions.append(getattr(table.c, key) == value)
sql = (
table
.update()
.where(sa.and_(
*conditions
))
.values(
new_values
)
)
try:
session.execute(sql)
except sa.exc.IntegrityError:
pass
def merge_update(self, table_name, from_, to, foreign_key):
session = object_session(from_)
constrained_columns = foreign_key['constrained_columns']
referred_columns = foreign_key['referred_columns']
metadata = from_.metadata
table = metadata.tables[table_name]
new_values = {}
for index, column in enumerate(constrained_columns):
new_values[column] = getattr(
to, referred_columns[index]
)
old_values = {}
for index, column in enumerate(constrained_columns):
old_values[column] = getattr(
from_, referred_columns[index]
)
if not self.memory_merge(session, table_name, old_values, new_values):
self.raw_merge(session, table, old_values, new_values)
def __call__(self, from_, to):
"""
Merges entity into another entity. After merging deletes the from_
argument entity.
"""
if from_.__tablename__ != to.__tablename__:
raise Exception()
session = object_session(from_)
foreign_keys = dependent_foreign_keys(from_)
for table_name in foreign_keys:
for foreign_key in foreign_keys[table_name]:
self.merge_update(table_name, from_, to, foreign_key)
session.delete(from_)
def merge(from_, to, merger=Merger):
"""
Merges entity into another entity. After merging deletes the from_ argument
entity.
After merging the from_ entity is deleted from database.
:param from_: an entity to merge into another entity
:param to: an entity to merge another entity into
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
class
"""
return Merger()(from_, to)

179
sqlalchemy_utils/types.py Normal file
View File

@@ -0,0 +1,179 @@
import phonenumbers
from functools import wraps
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from sqlalchemy import types
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
'''
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
different phone number formats to attributes, so they can be easily used
in templates. Phone number validation method is also implemented.
Takes the raw phone number and country code as params and parses them
into a PhoneNumber object.
.. _Python phonenumbers library:
https://github.com/daviddrysdale/python-phonenumbers
:param raw_number:
String representation of the phone number.
:param country_code:
Country code of the phone number.
'''
def __init__(self, raw_number, country_code=None):
self._phone_number = phonenumbers.parse(raw_number, country_code)
super(PhoneNumber, self).__init__(
country_code=self._phone_number.country_code,
national_number=self._phone_number.national_number,
extension=self._phone_number.extension,
italian_leading_zero=self._phone_number.italian_leading_zero,
raw_input=self._phone_number.raw_input,
country_code_source=self._phone_number.country_code_source,
preferred_domestic_carrier_code=
self._phone_number.preferred_domestic_carrier_code
)
self.national = phonenumbers.format_number(
self._phone_number,
phonenumbers.PhoneNumberFormat.NATIONAL
)
self.international = phonenumbers.format_number(
self._phone_number,
phonenumbers.PhoneNumberFormat.INTERNATIONAL
)
self.e164 = phonenumbers.format_number(
self._phone_number,
phonenumbers.PhoneNumberFormat.E164
)
def is_valid_number(self):
return phonenumbers.is_valid_number(self._phone_number)
class PhoneNumberType(types.TypeDecorator):
"""
Changes PhoneNumber objects to a string representation on the way in and
changes them back to PhoneNumber objects on the way out. If E164 is used
as storing format, no country code is needed for parsing the database
value to PhoneNumber object.
"""
STORE_FORMAT = 'e164'
impl = types.Unicode(20)
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
super(PhoneNumberType, self).__init__(*args, **kwargs)
self.country_code = country_code
self.impl = types.Unicode(max_length)
def process_bind_param(self, value, dialect):
return getattr(value, self.STORE_FORMAT)
def process_result_value(self, value, dialect):
return PhoneNumber(value, self.country_code)
class NumberRangeRawType(types.UserDefinedType):
"""
Raw number range type, only supports PostgreSQL for now.
"""
def get_col_spec(self):
return 'int4range'
class NumberRangeType(types.TypeDecorator):
impl = NumberRangeRawType
def process_bind_param(self, value, dialect):
return value
def process_result_value(self, value, dialect):
return NumberRange.from_normalized_str(value)
class NumberRange(object):
def __init__(self, min_value, max_value):
self.min_value = min_value
self.max_value = max_value
@classmethod
def from_normalized_str(cls, value):
if value is not None:
values = value[1:-1].split(',')
min_value, max_value = map(
lambda a: int(a.strip()), values
)
if value[0] == '(':
min_value += 1
if value[1] == ')':
max_value -= 1
return cls(min_value, max_value)
@classmethod
def from_str(cls, value):
if value is not None:
values = value.split('-')
min_value, max_value = map(
lambda a: int(a.strip()), values
)
return cls(min_value, max_value)
def __repr__(self):
return 'NumberRange(%r, %r)' % (self.min_value, self.max_value)
def __str__(self):
return '[%s, %s]' % (self.min_value, self.max_value)
def __add__(self, other):
try:
return NumberRange(
self.min_value + other.min_value,
self.max_value + other.max_value
)
except AttributeError:
return NotImplemented
def __iadd__(self, other):
try:
self.min_value += other.min_value
self.max_value += other.max_value
return self
except AttributeError:
return NotImplemented
def __sub__(self, other):
try:
return NumberRange(
self.min_value - other.min_value,
self.max_value - other.max_value
)
except AttributeError:
return NotImplemented
def __isub__(self, other):
try:
self.min_value -= other.min_value
self.max_value -= other.max_value
return self
except AttributeError:
return NotImplemented
class InstrumentedList(_InstrumentedList):
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
additional functionality."""
def any(self, attr):
return any(getattr(item, attr) for item in self)
def all(self, attr):
return all(getattr(item, attr) for item in self)
def instrumented_list(f):
@wraps(f)
def wrapper(*args, **kwargs):
return InstrumentedList([item for item in f(*args, **kwargs)])
return wrapper

397
tests.py
View File

@@ -1,397 +0,0 @@
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import (
escape_like,
sort_query,
InstrumentedList,
PhoneNumber,
PhoneNumberType,
merge
)
class TestCase(object):
def setup_method(self, method):
self.engine = create_engine('sqlite:///:memory:')
self.Base = declarative_base()
self.create_models()
self.Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def teardown_method(self, method):
self.session.close_all()
self.Base.metadata.drop_all(self.engine)
self.engine.dispose()
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
phone_number = sa.Column(PhoneNumberType())
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
self.User = User
self.Category = Category
self.Article = Article
class TestInstrumentedList(TestCase):
def test_any_returns_true_if_member_has_attr_defined(self):
category = self.Category()
category.articles.append(self.Article())
category.articles.append(self.Article(name=u'some name'))
assert category.articles.any('name')
def test_any_returns_false_if_no_member_has_attr_defined(self):
category = self.Category()
category.articles.append(self.Article())
assert not category.articles.any('name')
class TestEscapeLike(TestCase):
def test_escapes_wildcards(self):
assert escape_like('_*%') == '*_***%'
class TestSortQuery(TestCase):
def test_without_sort_param_returns_the_query_object_untouched(self):
query = self.session.query(self.Article)
sorted_query = sort_query(query, '')
assert query == sorted_query
def test_sort_by_column_ascending(self):
query = sort_query(self.session.query(self.Article), 'name')
assert 'ORDER BY article.name ASC' in str(query)
def test_sort_by_column_descending(self):
query = sort_query(self.session.query(self.Article), '-name')
assert 'ORDER BY article.name DESC' in str(query)
def test_skips_unknown_columns(self):
query = self.session.query(self.Article)
sorted_query = sort_query(query, '-unknown')
assert query == sorted_query
def test_sort_by_calculated_value_ascending(self):
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, 'articles')
assert 'ORDER BY articles ASC' in str(query)
def test_sort_by_calculated_value_descending(self):
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, '-articles')
assert 'ORDER BY articles DESC' in str(query)
def test_sort_by_joined_table_column(self):
query = self.session.query(self.Article).join(self.Article.category)
sorted_query = sort_query(query, 'category-name')
assert 'category.name ASC' in str(sorted_query)
class TestPhoneNumber(object):
def setup_method(self, method):
self.valid_phone_numbers = [
'040 1234567',
'+358 401234567',
'09 2501234',
'+358 92501234',
'0800 939393',
'09 4243 0456',
'0600 900 500'
]
self.invalid_phone_numbers = [
'abc',
'+040 1234567',
'0111234567',
'358'
]
def test_valid_phone_numbers(self):
for raw_number in self.valid_phone_numbers:
phone_number = PhoneNumber(raw_number, 'FI')
assert phone_number.is_valid_number()
def test_invalid_phone_numbers(self):
for raw_number in self.invalid_phone_numbers:
try:
phone_number = PhoneNumber(raw_number, 'FI')
assert not phone_number.is_valid_number()
except:
pass
def test_phone_number_attributes(self):
phone_number = PhoneNumber('+358401234567')
assert phone_number.e164 == u'+358401234567'
assert phone_number.international == u'+358 40 1234567'
assert phone_number.national == u'040 1234567'
class TestPhoneNumberType(TestCase):
def setup_method(self, method):
super(TestPhoneNumberType, self).setup_method(method)
self.phone_number = PhoneNumber(
'040 1234567',
'FI'
)
self.user = self.User()
self.user.name = u'Someone'
self.user.phone_number = self.phone_number
self.session.add(self.user)
self.session.commit()
def test_query_returns_phone_number_object(self):
queried_user = self.session.query(self.User).first()
assert queried_user.phone_number == self.phone_number
def test_phone_number_is_stored_as_string(self):
result = self.session.execute(
'SELECT phone_number FROM user WHERE id=:param',
{'param': self.user.id}
)
assert result.first()[0] == u'+358401234567'
class DatabaseTestCase(object):
def create_models(self):
pass
def setup_method(self, method):
self.engine = create_engine(
'sqlite:///'
)
#self.engine.echo = True
self.Base = declarative_base()
self.create_models()
self.Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def teardown_method(self, method):
self.engine.dispose()
self.Base.metadata.drop_all(self.engine)
self.session.expunge_all()
class TestMerge(DatabaseTestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(%r)' % self.name
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
self.User = User
self.BlogPost = BlogPost
def test_updates_foreign_keys(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack)
self.session.add(john)
self.session.add(jack)
self.session.add(post)
self.session.add(post2)
self.session.commit()
merge(john, jack)
assert post.author == jack
assert post2.author == jack
def test_deletes_from_entity(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
assert john in self.session.deleted
class TestMergeManyToManyAssociations(DatabaseTestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(%r)' % self.name
team_member = sa.Table(
'team_member', self.Base.metadata,
sa.Column(
'user_id', sa.Integer,
sa.ForeignKey('user.id', ondelete='CASCADE'),
primary_key=True
),
sa.Column(
'team_id', sa.Integer,
sa.ForeignKey('team.id', ondelete='CASCADE'),
primary_key=True
)
)
class Team(self.Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
members = sa.orm.relationship(
User,
secondary=team_member,
backref='teams'
)
self.User = User
self.Team = Team
def test_when_association_only_exists_in_from_entity(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(john)
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
assert john not in team.members
assert jack in team.members
def test_when_association_exists_in_both(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(john)
team.members.append(jack)
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
assert john not in team.members
assert jack in team.members
count = self.session.execute(
'SELECT COUNT(1) FROM team_member'
).fetchone()[0]
assert count == 1
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
def create_models(self):
class Team(self.Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class TeamMember(self.Base):
__tablename__ = 'team_member'
user_id = sa.Column(
sa.Integer,
sa.ForeignKey(User.id, ondelete='CASCADE'),
primary_key=True
)
team_id = sa.Column(
sa.Integer,
sa.ForeignKey(Team.id, ondelete='CASCADE'),
primary_key=True
)
role = sa.Column(sa.Unicode(255))
team = sa.orm.relationship(
Team,
backref=sa.orm.backref(
'members',
cascade='all, delete-orphan'
),
primaryjoin=team_id == Team.id,
)
user = sa.orm.relationship(
User,
backref=sa.orm.backref(
'memberships',
cascade='all, delete-orphan'
),
primaryjoin=user_id == User.id,
)
self.User = User
self.TeamMember = TeamMember
self.Team = Team
def test_when_association_only_exists_in_from_entity(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(self.TeamMember(user=john))
self.session.add(john)
self.session.add(jack)
self.session.add(team)
self.session.commit()
merge(john, jack)
self.session.commit()
users = [member.user for member in team.members]
assert john not in users
assert jack in users
def test_when_association_exists_in_both(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(self.TeamMember(user=john))
team.members.append(self.TeamMember(user=jack))
self.session.add(john)
self.session.add(jack)
self.session.add(team)
self.session.commit()
merge(john, jack)
users = [member.user for member in team.members]
assert john not in users
assert jack in users
assert self.session.query(self.TeamMember).count() == 1

86
tests/__init__.py Normal file
View File

@@ -0,0 +1,86 @@
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import (
escape_like,
sort_query,
InstrumentedList,
PhoneNumber,
PhoneNumberType,
merge
)
class TestCase(object):
def setup_method(self, method):
self.engine = create_engine(
'postgres://postgres@localhost/sqlalchemy_utils_test'
)
self.Base = declarative_base()
self.create_models()
self.Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def teardown_method(self, method):
self.session.close_all()
self.Base.metadata.drop_all(self.engine)
self.engine.dispose()
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
phone_number = sa.Column(PhoneNumberType())
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
self.User = User
self.Category = Category
self.Article = Article
class DatabaseTestCase(object):
def create_models(self):
pass
def setup_method(self, method):
self.engine = create_engine(
'sqlite:///'
)
#self.engine.echo = True
self.Base = declarative_base()
self.create_models()
self.Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def teardown_method(self, method):
self.engine.dispose()
self.Base.metadata.drop_all(self.engine)
self.session.expunge_all()

View File

@@ -0,0 +1,14 @@
from tests import TestCase
class TestInstrumentedList(TestCase):
def test_any_returns_true_if_member_has_attr_defined(self):
category = self.Category()
category.articles.append(self.Article())
category.articles.append(self.Article(name=u'some name'))
assert category.articles.any('name')
def test_any_returns_false_if_no_member_has_attr_defined(self):
category = self.Category()
category.articles.append(self.Article())
assert not category.articles.any('name')

196
tests/test_merge.py Normal file
View File

@@ -0,0 +1,196 @@
import sqlalchemy as sa
from sqlalchemy_utils import merge
from tests import DatabaseTestCase
class TestMerge(DatabaseTestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(%r)' % self.name
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.Unicode(255))
content = sa.Column(sa.UnicodeText)
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
self.User = User
self.BlogPost = BlogPost
def test_updates_foreign_keys(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
post = self.BlogPost(title=u'Some title', author=john)
post2 = self.BlogPost(title=u'Other title', author=jack)
self.session.add(john)
self.session.add(jack)
self.session.add(post)
self.session.add(post2)
self.session.commit()
merge(john, jack)
assert post.author == jack
assert post2.author == jack
def test_deletes_from_entity(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
assert john in self.session.deleted
class TestMergeManyToManyAssociations(DatabaseTestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
def __repr__(self):
return 'User(%r)' % self.name
team_member = sa.Table(
'team_member', self.Base.metadata,
sa.Column(
'user_id', sa.Integer,
sa.ForeignKey('user.id', ondelete='CASCADE'),
primary_key=True
),
sa.Column(
'team_id', sa.Integer,
sa.ForeignKey('team.id', ondelete='CASCADE'),
primary_key=True
)
)
class Team(self.Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
members = sa.orm.relationship(
User,
secondary=team_member,
backref='teams'
)
self.User = User
self.Team = Team
def test_when_association_only_exists_in_from_entity(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(john)
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
assert john not in team.members
assert jack in team.members
def test_when_association_exists_in_both(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(john)
team.members.append(jack)
self.session.add(john)
self.session.add(jack)
self.session.commit()
merge(john, jack)
assert john not in team.members
assert jack in team.members
count = self.session.execute(
'SELECT COUNT(1) FROM team_member'
).fetchone()[0]
assert count == 1
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
def create_models(self):
class Team(self.Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class TeamMember(self.Base):
__tablename__ = 'team_member'
user_id = sa.Column(
sa.Integer,
sa.ForeignKey(User.id, ondelete='CASCADE'),
primary_key=True
)
team_id = sa.Column(
sa.Integer,
sa.ForeignKey(Team.id, ondelete='CASCADE'),
primary_key=True
)
role = sa.Column(sa.Unicode(255))
team = sa.orm.relationship(
Team,
backref=sa.orm.backref(
'members',
cascade='all, delete-orphan'
),
primaryjoin=team_id == Team.id,
)
user = sa.orm.relationship(
User,
backref=sa.orm.backref(
'memberships',
cascade='all, delete-orphan'
),
primaryjoin=user_id == User.id,
)
self.User = User
self.TeamMember = TeamMember
self.Team = Team
def test_when_association_only_exists_in_from_entity(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(self.TeamMember(user=john))
self.session.add(john)
self.session.add(jack)
self.session.add(team)
self.session.commit()
merge(john, jack)
self.session.commit()
users = [member.user for member in team.members]
assert john not in users
assert jack in users
def test_when_association_exists_in_both(self):
john = self.User(name=u'John')
jack = self.User(name=u'Jack')
team = self.Team(name=u'Team')
team.members.append(self.TeamMember(user=john))
team.members.append(self.TeamMember(user=jack))
self.session.add(john)
self.session.add(jack)
self.session.add(team)
self.session.commit()
merge(john, jack)
users = [member.user for member in team.members]
assert john not in users
assert jack in users
assert self.session.query(self.TeamMember).count() == 1

View File

@@ -0,0 +1,65 @@
from tests import TestCase
from sqlalchemy_utils import PhoneNumber
class TestPhoneNumber(object):
def setup_method(self, method):
self.valid_phone_numbers = [
'040 1234567',
'+358 401234567',
'09 2501234',
'+358 92501234',
'0800 939393',
'09 4243 0456',
'0600 900 500'
]
self.invalid_phone_numbers = [
'abc',
'+040 1234567',
'0111234567',
'358'
]
def test_valid_phone_numbers(self):
for raw_number in self.valid_phone_numbers:
phone_number = PhoneNumber(raw_number, 'FI')
assert phone_number.is_valid_number()
def test_invalid_phone_numbers(self):
for raw_number in self.invalid_phone_numbers:
try:
phone_number = PhoneNumber(raw_number, 'FI')
assert not phone_number.is_valid_number()
except:
pass
def test_phone_number_attributes(self):
phone_number = PhoneNumber('+358401234567')
assert phone_number.e164 == u'+358401234567'
assert phone_number.international == u'+358 40 1234567'
assert phone_number.national == u'040 1234567'
class TestPhoneNumberType(TestCase):
def setup_method(self, method):
super(TestPhoneNumberType, self).setup_method(method)
self.phone_number = PhoneNumber(
'040 1234567',
'FI'
)
self.user = self.User()
self.user.name = u'Someone'
self.user.phone_number = self.phone_number
self.session.add(self.user)
self.session.commit()
def test_query_returns_phone_number_object(self):
queried_user = self.session.query(self.User).first()
assert queried_user.phone_number == self.phone_number
def test_phone_number_is_stored_as_string(self):
result = self.session.execute(
'SELECT phone_number FROM "user" WHERE id=:param',
{'param': self.user.id}
)
assert result.first()[0] == u'+358401234567'

View File

@@ -0,0 +1,47 @@
import sqlalchemy as sa
from sqlalchemy_utils import escape_like, sort_query
from tests import TestCase
class TestEscapeLike(TestCase):
def test_escapes_wildcards(self):
assert escape_like('_*%') == '*_***%'
class TestSortQuery(TestCase):
def test_without_sort_param_returns_the_query_object_untouched(self):
query = self.session.query(self.Article)
sorted_query = sort_query(query, '')
assert query == sorted_query
def test_sort_by_column_ascending(self):
query = sort_query(self.session.query(self.Article), 'name')
assert 'ORDER BY article.name ASC' in str(query)
def test_sort_by_column_descending(self):
query = sort_query(self.session.query(self.Article), '-name')
assert 'ORDER BY article.name DESC' in str(query)
def test_skips_unknown_columns(self):
query = self.session.query(self.Article)
sorted_query = sort_query(query, '-unknown')
assert query == sorted_query
def test_sort_by_calculated_value_ascending(self):
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, 'articles')
assert 'ORDER BY articles ASC' in str(query)
def test_sort_by_calculated_value_descending(self):
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, '-articles')
assert 'ORDER BY articles DESC' in str(query)
def test_sort_by_joined_table_column(self):
query = self.session.query(self.Article).join(self.Article.category)
sorted_query = sort_query(query, 'category-name')
assert 'category.name ASC' in str(sorted_query)