Added NumberRange types, refactored file structure
This commit is contained in:
@@ -1,372 +1,27 @@
|
||||
import phonenumbers
|
||||
from functools import wraps
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.orm import defer, object_session, mapperlib
|
||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.sql.expression import desc, asc
|
||||
from sqlalchemy import types
|
||||
|
||||
|
||||
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
|
||||
'''
|
||||
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
||||
different phone number formats to attributes, so they can be easily used
|
||||
in templates. Phone number validation method is also implemented.
|
||||
|
||||
Takes the raw phone number and country code as params and parses them
|
||||
into a PhoneNumber object.
|
||||
|
||||
.. _Python phonenumbers library:
|
||||
https://github.com/daviddrysdale/python-phonenumbers
|
||||
|
||||
:param raw_number:
|
||||
String representation of the phone number.
|
||||
:param country_code:
|
||||
Country code of the phone number.
|
||||
'''
|
||||
def __init__(self, raw_number, country_code=None):
|
||||
self._phone_number = phonenumbers.parse(raw_number, country_code)
|
||||
super(PhoneNumber, self).__init__(
|
||||
country_code=self._phone_number.country_code,
|
||||
national_number=self._phone_number.national_number,
|
||||
extension=self._phone_number.extension,
|
||||
italian_leading_zero=self._phone_number.italian_leading_zero,
|
||||
raw_input=self._phone_number.raw_input,
|
||||
country_code_source=self._phone_number.country_code_source,
|
||||
preferred_domestic_carrier_code=
|
||||
self._phone_number.preferred_domestic_carrier_code
|
||||
)
|
||||
self.national = phonenumbers.format_number(
|
||||
self._phone_number,
|
||||
phonenumbers.PhoneNumberFormat.NATIONAL
|
||||
)
|
||||
self.international = phonenumbers.format_number(
|
||||
self._phone_number,
|
||||
phonenumbers.PhoneNumberFormat.INTERNATIONAL
|
||||
)
|
||||
self.e164 = phonenumbers.format_number(
|
||||
self._phone_number,
|
||||
phonenumbers.PhoneNumberFormat.E164
|
||||
)
|
||||
|
||||
def is_valid_number(self):
|
||||
return phonenumbers.is_valid_number(self._phone_number)
|
||||
|
||||
|
||||
class PhoneNumberType(types.TypeDecorator):
|
||||
"""
|
||||
Changes PhoneNumber objects to a string representation on the way in and
|
||||
changes them back to PhoneNumber objects on the way out. If E164 is used
|
||||
as storing format, no country code is needed for parsing the database
|
||||
value to PhoneNumber object.
|
||||
"""
|
||||
STORE_FORMAT = 'e164'
|
||||
impl = types.Unicode(20)
|
||||
|
||||
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
|
||||
super(PhoneNumberType, self).__init__(*args, **kwargs)
|
||||
self.country_code = country_code
|
||||
self.impl = types.Unicode(max_length)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return getattr(value, self.STORE_FORMAT)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return PhoneNumber(value, self.country_code)
|
||||
|
||||
|
||||
class InstrumentedList(_InstrumentedList):
|
||||
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
||||
additional functionality."""
|
||||
|
||||
def any(self, attr):
|
||||
return any(getattr(item, attr) for item in self)
|
||||
|
||||
def all(self, attr):
|
||||
return all(getattr(item, attr) for item in self)
|
||||
|
||||
|
||||
def instrumented_list(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
return InstrumentedList([item for item in f(*args, **kwargs)])
|
||||
return wrapper
|
||||
|
||||
|
||||
def sort_query(query, sort):
|
||||
"""
|
||||
Applies an sql ORDER BY for given query. This function can be easily used
|
||||
with user-defined sorting.
|
||||
|
||||
The examples use the following model definition:
|
||||
|
||||
>>> import sqlalchemy as sa
|
||||
>>> from sqlalchemy import create_engine
|
||||
>>> from sqlalchemy.orm import sessionmaker
|
||||
>>> from sqlalchemy.ext.declarative import declarative_base
|
||||
>>> from sqlalchemy_utils import sort_query
|
||||
>>>
|
||||
>>>
|
||||
>>> engine = create_engine(
|
||||
... 'sqlite:///'
|
||||
... )
|
||||
>>> Base = declarative_base()
|
||||
>>> Session = sessionmaker(bind=engine)
|
||||
>>> session = Session()
|
||||
>>>
|
||||
>>> class Category(Base):
|
||||
... __tablename__ = 'category'
|
||||
... id = sa.Column(sa.Integer, primary_key=True)
|
||||
... name = sa.Column(sa.Unicode(255))
|
||||
>>>
|
||||
>>> class Article(Base):
|
||||
... __tablename__ = 'article'
|
||||
... id = sa.Column(sa.Integer, primary_key=True)
|
||||
... name = sa.Column(sa.Unicode(255))
|
||||
... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
...
|
||||
... category = sa.orm.relationship(
|
||||
... Category, primaryjoin=category_id == Category.id
|
||||
... )
|
||||
|
||||
|
||||
|
||||
1. Applying simple ascending sort
|
||||
|
||||
>>> query = session.query(Article)
|
||||
>>> query = sort_query(query, 'name')
|
||||
|
||||
2. Appying descending sort
|
||||
|
||||
>>> query = sort_query(query, '-name')
|
||||
|
||||
3. Applying sort to custom calculated label
|
||||
|
||||
>>> query = session.query(
|
||||
... Category, db.func.count(Article.id).label('articles')
|
||||
... )
|
||||
>>> query = sort_query(query, 'articles')
|
||||
|
||||
4. Applying sort to joined table column
|
||||
|
||||
>>> query = session.query(Article).join(Article.category)
|
||||
>>> query = sort_query(query, 'category-name')
|
||||
|
||||
|
||||
:param query: query to be modified
|
||||
:param sort: string that defines the label or column to sort the query by
|
||||
:param errors: whether or not to raise exceptions if unknown sort column
|
||||
is passed
|
||||
"""
|
||||
entities = [entity.entity_zero.class_ for entity in query._entities]
|
||||
for mapper in query._join_entities:
|
||||
if isinstance(mapper, Mapper):
|
||||
entities.append(mapper.class_)
|
||||
else:
|
||||
entities.append(mapper)
|
||||
|
||||
# get all label names for queries such as:
|
||||
# db.session.query(Category, db.func.count(Article.id).label('articles'))
|
||||
labels = []
|
||||
for entity in query._entities:
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
labels.append(entity._label_name)
|
||||
|
||||
if not sort:
|
||||
return query
|
||||
|
||||
if sort[0] == '-':
|
||||
func = desc
|
||||
sort = sort[1:]
|
||||
else:
|
||||
func = asc
|
||||
|
||||
component = None
|
||||
parts = sort.split('-')
|
||||
if len(parts) > 1:
|
||||
component = parts[0]
|
||||
sort = parts[1]
|
||||
if sort in labels:
|
||||
return query.order_by(func(sort))
|
||||
|
||||
for entity in entities:
|
||||
table = entity.__table__
|
||||
if component and table.name != component:
|
||||
continue
|
||||
if sort in table.columns:
|
||||
try:
|
||||
attr = getattr(entity, sort)
|
||||
query = query.order_by(func(attr))
|
||||
except AttributeError:
|
||||
pass
|
||||
break
|
||||
return query
|
||||
|
||||
|
||||
def defer_except(query, columns):
|
||||
"""
|
||||
Deferred loads all columns in given query, except the ones given.
|
||||
|
||||
This function is very useful when working with models with myriad of
|
||||
columns and you want to deferred load many columns.
|
||||
|
||||
>>> from sqlalchemy_utils import defer_except
|
||||
>>> query = session.query(Article)
|
||||
>>> query = defer_except(Article, [Article.id, Article.name])
|
||||
|
||||
:param columns: columns not to deferred load
|
||||
"""
|
||||
model = query._entities[0].entity_zero.class_
|
||||
fields = set(model._sa_class_manager.values())
|
||||
for field in fields:
|
||||
property_ = field.property
|
||||
if isinstance(property_, ColumnProperty):
|
||||
column = property_.columns[0]
|
||||
if column.name not in columns:
|
||||
query = query.options(defer(property_.key))
|
||||
return query
|
||||
|
||||
|
||||
def escape_like(string, escape_char='*'):
|
||||
"""
|
||||
Escapes the string paremeter used in SQL LIKE expressions
|
||||
|
||||
>>> from sqlalchemy_utils import escape_like
|
||||
>>> query = session.query(User).filter(
|
||||
... User.name.ilike(escape_like('John'))
|
||||
... )
|
||||
|
||||
|
||||
:param string: a string to escape
|
||||
:param escape_char: escape character
|
||||
"""
|
||||
return (
|
||||
string
|
||||
.replace(escape_char, escape_char * 2)
|
||||
.replace('%', escape_char + '%')
|
||||
.replace('_', escape_char + '_')
|
||||
)
|
||||
|
||||
|
||||
def dependent_foreign_keys(model_class):
|
||||
"""
|
||||
Returns dependent foreign keys as dicts for given model class.
|
||||
|
||||
** Experimental function **
|
||||
"""
|
||||
session = object_session(model_class)
|
||||
|
||||
engine = session.bind
|
||||
inspector = reflection.Inspector.from_engine(engine)
|
||||
table_names = inspector.get_table_names()
|
||||
|
||||
dependent_foreign_keys = {}
|
||||
|
||||
for table_name in table_names:
|
||||
fks = inspector.get_foreign_keys(table_name)
|
||||
if fks:
|
||||
dependent_foreign_keys[table_name] = []
|
||||
for fk in fks:
|
||||
if fk['referred_table'] == model_class.__tablename__:
|
||||
dependent_foreign_keys[table_name].append(fk)
|
||||
return dependent_foreign_keys
|
||||
|
||||
|
||||
class Merger(object):
|
||||
def memory_merge(self, session, table_name, old_values, new_values):
|
||||
# try to fetch mappers for given table and update in memory objects as
|
||||
# well as database table
|
||||
found = False
|
||||
for mapper in mapperlib._mapper_registry:
|
||||
class_ = mapper.class_
|
||||
if table_name == class_.__table__.name:
|
||||
try:
|
||||
(
|
||||
session.query(mapper.class_)
|
||||
.filter_by(**old_values)
|
||||
.update(
|
||||
new_values,
|
||||
'fetch'
|
||||
)
|
||||
)
|
||||
except sa.exc.IntegrityError:
|
||||
pass
|
||||
found = True
|
||||
return found
|
||||
|
||||
def raw_merge(self, session, table, old_values, new_values):
|
||||
conditions = []
|
||||
for key, value in old_values.items():
|
||||
conditions.append(getattr(table.c, key) == value)
|
||||
sql = (
|
||||
table
|
||||
.update()
|
||||
.where(sa.and_(
|
||||
*conditions
|
||||
))
|
||||
.values(
|
||||
new_values
|
||||
)
|
||||
)
|
||||
try:
|
||||
session.execute(sql)
|
||||
except sa.exc.IntegrityError:
|
||||
pass
|
||||
|
||||
def merge_update(self, table_name, from_, to, foreign_key):
|
||||
session = object_session(from_)
|
||||
constrained_columns = foreign_key['constrained_columns']
|
||||
referred_columns = foreign_key['referred_columns']
|
||||
metadata = from_.metadata
|
||||
table = metadata.tables[table_name]
|
||||
|
||||
new_values = {}
|
||||
for index, column in enumerate(constrained_columns):
|
||||
new_values[column] = getattr(
|
||||
to, referred_columns[index]
|
||||
)
|
||||
|
||||
old_values = {}
|
||||
for index, column in enumerate(constrained_columns):
|
||||
old_values[column] = getattr(
|
||||
from_, referred_columns[index]
|
||||
)
|
||||
|
||||
if not self.memory_merge(session, table_name, old_values, new_values):
|
||||
self.raw_merge(session, table, old_values, new_values)
|
||||
|
||||
def __call__(self, from_, to):
|
||||
"""
|
||||
Merges entity into another entity. After merging deletes the from_
|
||||
argument entity.
|
||||
"""
|
||||
if from_.__tablename__ != to.__tablename__:
|
||||
raise Exception()
|
||||
|
||||
session = object_session(from_)
|
||||
foreign_keys = dependent_foreign_keys(from_)
|
||||
|
||||
for table_name in foreign_keys:
|
||||
for foreign_key in foreign_keys[table_name]:
|
||||
self.merge_update(table_name, from_, to, foreign_key)
|
||||
|
||||
session.delete(from_)
|
||||
|
||||
|
||||
def merge(from_, to, merger=Merger):
|
||||
"""
|
||||
Merges entity into another entity. After merging deletes the from_ argument
|
||||
entity.
|
||||
|
||||
After merging the from_ entity is deleted from database.
|
||||
|
||||
:param from_: an entity to merge into another entity
|
||||
:param to: an entity to merge another entity into
|
||||
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
|
||||
class
|
||||
"""
|
||||
return Merger()(from_, to)
|
||||
from .functions import sort_query, defer_except, escape_like
|
||||
from .merge import merge, Merger
|
||||
from .types import (
|
||||
instrumented_list,
|
||||
InstrumentedList,
|
||||
PhoneNumber,
|
||||
PhoneNumberType,
|
||||
NumberRange,
|
||||
NumberRangeRawType,
|
||||
NumberRangeType
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
sort_query,
|
||||
defer_except,
|
||||
escape_like,
|
||||
instrumented_list,
|
||||
merge,
|
||||
InstrumentedList,
|
||||
Merger,
|
||||
NumberRange,
|
||||
NumberRangeRawType,
|
||||
NumberRangeType,
|
||||
PhoneNumber,
|
||||
PhoneNumberType,
|
||||
)
|
||||
|
159
sqlalchemy_utils/functions.py
Normal file
159
sqlalchemy_utils/functions.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.sql.expression import desc, asc
|
||||
|
||||
|
||||
def sort_query(query, sort):
|
||||
"""
|
||||
Applies an sql ORDER BY for given query. This function can be easily used
|
||||
with user-defined sorting.
|
||||
|
||||
The examples use the following model definition:
|
||||
|
||||
>>> import sqlalchemy as sa
|
||||
>>> from sqlalchemy import create_engine
|
||||
>>> from sqlalchemy.orm import sessionmaker
|
||||
>>> from sqlalchemy.ext.declarative import declarative_base
|
||||
>>> from sqlalchemy_utils import sort_query
|
||||
>>>
|
||||
>>>
|
||||
>>> engine = create_engine(
|
||||
... 'sqlite:///'
|
||||
... )
|
||||
>>> Base = declarative_base()
|
||||
>>> Session = sessionmaker(bind=engine)
|
||||
>>> session = Session()
|
||||
>>>
|
||||
>>> class Category(Base):
|
||||
... __tablename__ = 'category'
|
||||
... id = sa.Column(sa.Integer, primary_key=True)
|
||||
... name = sa.Column(sa.Unicode(255))
|
||||
>>>
|
||||
>>> class Article(Base):
|
||||
... __tablename__ = 'article'
|
||||
... id = sa.Column(sa.Integer, primary_key=True)
|
||||
... name = sa.Column(sa.Unicode(255))
|
||||
... category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
...
|
||||
... category = sa.orm.relationship(
|
||||
... Category, primaryjoin=category_id == Category.id
|
||||
... )
|
||||
|
||||
|
||||
|
||||
1. Applying simple ascending sort
|
||||
|
||||
>>> query = session.query(Article)
|
||||
>>> query = sort_query(query, 'name')
|
||||
|
||||
2. Appying descending sort
|
||||
|
||||
>>> query = sort_query(query, '-name')
|
||||
|
||||
3. Applying sort to custom calculated label
|
||||
|
||||
>>> query = session.query(
|
||||
... Category, db.func.count(Article.id).label('articles')
|
||||
... )
|
||||
>>> query = sort_query(query, 'articles')
|
||||
|
||||
4. Applying sort to joined table column
|
||||
|
||||
>>> query = session.query(Article).join(Article.category)
|
||||
>>> query = sort_query(query, 'category-name')
|
||||
|
||||
|
||||
:param query: query to be modified
|
||||
:param sort: string that defines the label or column to sort the query by
|
||||
:param errors: whether or not to raise exceptions if unknown sort column
|
||||
is passed
|
||||
"""
|
||||
entities = [entity.entity_zero.class_ for entity in query._entities]
|
||||
for mapper in query._join_entities:
|
||||
if isinstance(mapper, Mapper):
|
||||
entities.append(mapper.class_)
|
||||
else:
|
||||
entities.append(mapper)
|
||||
|
||||
# get all label names for queries such as:
|
||||
# db.session.query(Category, db.func.count(Article.id).label('articles'))
|
||||
labels = []
|
||||
for entity in query._entities:
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
labels.append(entity._label_name)
|
||||
|
||||
if not sort:
|
||||
return query
|
||||
|
||||
if sort[0] == '-':
|
||||
func = desc
|
||||
sort = sort[1:]
|
||||
else:
|
||||
func = asc
|
||||
|
||||
component = None
|
||||
parts = sort.split('-')
|
||||
if len(parts) > 1:
|
||||
component = parts[0]
|
||||
sort = parts[1]
|
||||
if sort in labels:
|
||||
return query.order_by(func(sort))
|
||||
|
||||
for entity in entities:
|
||||
table = entity.__table__
|
||||
if component and table.name != component:
|
||||
continue
|
||||
if sort in table.columns:
|
||||
try:
|
||||
attr = getattr(entity, sort)
|
||||
query = query.order_by(func(attr))
|
||||
except AttributeError:
|
||||
pass
|
||||
break
|
||||
return query
|
||||
|
||||
|
||||
def defer_except(query, columns):
|
||||
"""
|
||||
Deferred loads all columns in given query, except the ones given.
|
||||
|
||||
This function is very useful when working with models with myriad of
|
||||
columns and you want to deferred load many columns.
|
||||
|
||||
>>> from sqlalchemy_utils import defer_except
|
||||
>>> query = session.query(Article)
|
||||
>>> query = defer_except(Article, [Article.id, Article.name])
|
||||
|
||||
:param columns: columns not to deferred load
|
||||
"""
|
||||
model = query._entities[0].entity_zero.class_
|
||||
fields = set(model._sa_class_manager.values())
|
||||
for field in fields:
|
||||
property_ = field.property
|
||||
if isinstance(property_, ColumnProperty):
|
||||
column = property_.columns[0]
|
||||
if column.name not in columns:
|
||||
query = query.options(defer(property_.key))
|
||||
return query
|
||||
|
||||
|
||||
def escape_like(string, escape_char='*'):
|
||||
"""
|
||||
Escapes the string paremeter used in SQL LIKE expressions
|
||||
|
||||
>>> from sqlalchemy_utils import escape_like
|
||||
>>> query = session.query(User).filter(
|
||||
... User.name.ilike(escape_like('John'))
|
||||
... )
|
||||
|
||||
|
||||
:param string: a string to escape
|
||||
:param escape_char: escape character
|
||||
"""
|
||||
return (
|
||||
string
|
||||
.replace(escape_char, escape_char * 2)
|
||||
.replace('%', escape_char + '%')
|
||||
.replace('_', escape_char + '_')
|
||||
)
|
123
sqlalchemy_utils/merge.py
Normal file
123
sqlalchemy_utils/merge.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.orm import object_session, mapperlib
|
||||
|
||||
|
||||
def dependent_foreign_keys(model_class):
|
||||
"""
|
||||
Returns dependent foreign keys as dicts for given model class.
|
||||
|
||||
** Experimental function **
|
||||
"""
|
||||
session = object_session(model_class)
|
||||
|
||||
engine = session.bind
|
||||
inspector = reflection.Inspector.from_engine(engine)
|
||||
table_names = inspector.get_table_names()
|
||||
|
||||
dependent_foreign_keys = {}
|
||||
|
||||
for table_name in table_names:
|
||||
fks = inspector.get_foreign_keys(table_name)
|
||||
if fks:
|
||||
dependent_foreign_keys[table_name] = []
|
||||
for fk in fks:
|
||||
if fk['referred_table'] == model_class.__tablename__:
|
||||
dependent_foreign_keys[table_name].append(fk)
|
||||
return dependent_foreign_keys
|
||||
|
||||
|
||||
class Merger(object):
|
||||
def memory_merge(self, session, table_name, old_values, new_values):
|
||||
# try to fetch mappers for given table and update in memory objects as
|
||||
# well as database table
|
||||
found = False
|
||||
for mapper in mapperlib._mapper_registry:
|
||||
class_ = mapper.class_
|
||||
if table_name == class_.__table__.name:
|
||||
try:
|
||||
(
|
||||
session.query(mapper.class_)
|
||||
.filter_by(**old_values)
|
||||
.update(
|
||||
new_values,
|
||||
'fetch'
|
||||
)
|
||||
)
|
||||
except sa.exc.IntegrityError:
|
||||
pass
|
||||
found = True
|
||||
return found
|
||||
|
||||
def raw_merge(self, session, table, old_values, new_values):
|
||||
conditions = []
|
||||
for key, value in old_values.items():
|
||||
conditions.append(getattr(table.c, key) == value)
|
||||
sql = (
|
||||
table
|
||||
.update()
|
||||
.where(sa.and_(
|
||||
*conditions
|
||||
))
|
||||
.values(
|
||||
new_values
|
||||
)
|
||||
)
|
||||
try:
|
||||
session.execute(sql)
|
||||
except sa.exc.IntegrityError:
|
||||
pass
|
||||
|
||||
def merge_update(self, table_name, from_, to, foreign_key):
|
||||
session = object_session(from_)
|
||||
constrained_columns = foreign_key['constrained_columns']
|
||||
referred_columns = foreign_key['referred_columns']
|
||||
metadata = from_.metadata
|
||||
table = metadata.tables[table_name]
|
||||
|
||||
new_values = {}
|
||||
for index, column in enumerate(constrained_columns):
|
||||
new_values[column] = getattr(
|
||||
to, referred_columns[index]
|
||||
)
|
||||
|
||||
old_values = {}
|
||||
for index, column in enumerate(constrained_columns):
|
||||
old_values[column] = getattr(
|
||||
from_, referred_columns[index]
|
||||
)
|
||||
|
||||
if not self.memory_merge(session, table_name, old_values, new_values):
|
||||
self.raw_merge(session, table, old_values, new_values)
|
||||
|
||||
def __call__(self, from_, to):
|
||||
"""
|
||||
Merges entity into another entity. After merging deletes the from_
|
||||
argument entity.
|
||||
"""
|
||||
if from_.__tablename__ != to.__tablename__:
|
||||
raise Exception()
|
||||
|
||||
session = object_session(from_)
|
||||
foreign_keys = dependent_foreign_keys(from_)
|
||||
|
||||
for table_name in foreign_keys:
|
||||
for foreign_key in foreign_keys[table_name]:
|
||||
self.merge_update(table_name, from_, to, foreign_key)
|
||||
|
||||
session.delete(from_)
|
||||
|
||||
|
||||
def merge(from_, to, merger=Merger):
|
||||
"""
|
||||
Merges entity into another entity. After merging deletes the from_ argument
|
||||
entity.
|
||||
|
||||
After merging the from_ entity is deleted from database.
|
||||
|
||||
:param from_: an entity to merge into another entity
|
||||
:param to: an entity to merge another entity into
|
||||
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
|
||||
class
|
||||
"""
|
||||
return Merger()(from_, to)
|
179
sqlalchemy_utils/types.py
Normal file
179
sqlalchemy_utils/types.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import phonenumbers
|
||||
from functools import wraps
|
||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||
from sqlalchemy import types
|
||||
|
||||
|
||||
class PhoneNumber(phonenumbers.phonenumber.PhoneNumber):
|
||||
'''
|
||||
Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
|
||||
different phone number formats to attributes, so they can be easily used
|
||||
in templates. Phone number validation method is also implemented.
|
||||
|
||||
Takes the raw phone number and country code as params and parses them
|
||||
into a PhoneNumber object.
|
||||
|
||||
.. _Python phonenumbers library:
|
||||
https://github.com/daviddrysdale/python-phonenumbers
|
||||
|
||||
:param raw_number:
|
||||
String representation of the phone number.
|
||||
:param country_code:
|
||||
Country code of the phone number.
|
||||
'''
|
||||
def __init__(self, raw_number, country_code=None):
|
||||
self._phone_number = phonenumbers.parse(raw_number, country_code)
|
||||
super(PhoneNumber, self).__init__(
|
||||
country_code=self._phone_number.country_code,
|
||||
national_number=self._phone_number.national_number,
|
||||
extension=self._phone_number.extension,
|
||||
italian_leading_zero=self._phone_number.italian_leading_zero,
|
||||
raw_input=self._phone_number.raw_input,
|
||||
country_code_source=self._phone_number.country_code_source,
|
||||
preferred_domestic_carrier_code=
|
||||
self._phone_number.preferred_domestic_carrier_code
|
||||
)
|
||||
self.national = phonenumbers.format_number(
|
||||
self._phone_number,
|
||||
phonenumbers.PhoneNumberFormat.NATIONAL
|
||||
)
|
||||
self.international = phonenumbers.format_number(
|
||||
self._phone_number,
|
||||
phonenumbers.PhoneNumberFormat.INTERNATIONAL
|
||||
)
|
||||
self.e164 = phonenumbers.format_number(
|
||||
self._phone_number,
|
||||
phonenumbers.PhoneNumberFormat.E164
|
||||
)
|
||||
|
||||
def is_valid_number(self):
|
||||
return phonenumbers.is_valid_number(self._phone_number)
|
||||
|
||||
|
||||
class PhoneNumberType(types.TypeDecorator):
|
||||
"""
|
||||
Changes PhoneNumber objects to a string representation on the way in and
|
||||
changes them back to PhoneNumber objects on the way out. If E164 is used
|
||||
as storing format, no country code is needed for parsing the database
|
||||
value to PhoneNumber object.
|
||||
"""
|
||||
STORE_FORMAT = 'e164'
|
||||
impl = types.Unicode(20)
|
||||
|
||||
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
|
||||
super(PhoneNumberType, self).__init__(*args, **kwargs)
|
||||
self.country_code = country_code
|
||||
self.impl = types.Unicode(max_length)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return getattr(value, self.STORE_FORMAT)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return PhoneNumber(value, self.country_code)
|
||||
|
||||
|
||||
class NumberRangeRawType(types.UserDefinedType):
|
||||
"""
|
||||
Raw number range type, only supports PostgreSQL for now.
|
||||
"""
|
||||
def get_col_spec(self):
|
||||
return 'int4range'
|
||||
|
||||
|
||||
class NumberRangeType(types.TypeDecorator):
|
||||
impl = NumberRangeRawType
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return NumberRange.from_normalized_str(value)
|
||||
|
||||
|
||||
class NumberRange(object):
|
||||
def __init__(self, min_value, max_value):
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
@classmethod
|
||||
def from_normalized_str(cls, value):
|
||||
if value is not None:
|
||||
values = value[1:-1].split(',')
|
||||
min_value, max_value = map(
|
||||
lambda a: int(a.strip()), values
|
||||
)
|
||||
|
||||
if value[0] == '(':
|
||||
min_value += 1
|
||||
|
||||
if value[1] == ')':
|
||||
max_value -= 1
|
||||
|
||||
return cls(min_value, max_value)
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, value):
|
||||
if value is not None:
|
||||
values = value.split('-')
|
||||
min_value, max_value = map(
|
||||
lambda a: int(a.strip()), values
|
||||
)
|
||||
return cls(min_value, max_value)
|
||||
|
||||
def __repr__(self):
|
||||
return 'NumberRange(%r, %r)' % (self.min_value, self.max_value)
|
||||
|
||||
def __str__(self):
|
||||
return '[%s, %s]' % (self.min_value, self.max_value)
|
||||
|
||||
def __add__(self, other):
|
||||
try:
|
||||
return NumberRange(
|
||||
self.min_value + other.min_value,
|
||||
self.max_value + other.max_value
|
||||
)
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
def __iadd__(self, other):
|
||||
try:
|
||||
self.min_value += other.min_value
|
||||
self.max_value += other.max_value
|
||||
return self
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
def __sub__(self, other):
|
||||
try:
|
||||
return NumberRange(
|
||||
self.min_value - other.min_value,
|
||||
self.max_value - other.max_value
|
||||
)
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
def __isub__(self, other):
|
||||
try:
|
||||
self.min_value -= other.min_value
|
||||
self.max_value -= other.max_value
|
||||
return self
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class InstrumentedList(_InstrumentedList):
|
||||
"""Enhanced version of SQLAlchemy InstrumentedList. Provides some
|
||||
additional functionality."""
|
||||
|
||||
def any(self, attr):
|
||||
return any(getattr(item, attr) for item in self)
|
||||
|
||||
def all(self, attr):
|
||||
return all(getattr(item, attr) for item in self)
|
||||
|
||||
|
||||
def instrumented_list(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
return InstrumentedList([item for item in f(*args, **kwargs)])
|
||||
return wrapper
|
397
tests.py
397
tests.py
@@ -1,397 +0,0 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import (
|
||||
escape_like,
|
||||
sort_query,
|
||||
InstrumentedList,
|
||||
PhoneNumber,
|
||||
PhoneNumberType,
|
||||
merge
|
||||
)
|
||||
|
||||
|
||||
class TestCase(object):
|
||||
|
||||
def setup_method(self, method):
|
||||
self.engine = create_engine('sqlite:///:memory:')
|
||||
self.Base = declarative_base()
|
||||
|
||||
self.create_models()
|
||||
self.Base.metadata.create_all(self.engine)
|
||||
|
||||
Session = sessionmaker(bind=self.engine)
|
||||
self.session = Session()
|
||||
|
||||
def teardown_method(self, method):
|
||||
self.session.close_all()
|
||||
self.Base.metadata.drop_all(self.engine)
|
||||
self.engine.dispose()
|
||||
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
phone_number = sa.Column(PhoneNumberType())
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
category = sa.orm.relationship(
|
||||
Category,
|
||||
primaryjoin=category_id == Category.id,
|
||||
backref=sa.orm.backref(
|
||||
'articles',
|
||||
collection_class=InstrumentedList
|
||||
)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Category = Category
|
||||
self.Article = Article
|
||||
|
||||
|
||||
class TestInstrumentedList(TestCase):
|
||||
def test_any_returns_true_if_member_has_attr_defined(self):
|
||||
category = self.Category()
|
||||
category.articles.append(self.Article())
|
||||
category.articles.append(self.Article(name=u'some name'))
|
||||
assert category.articles.any('name')
|
||||
|
||||
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
||||
category = self.Category()
|
||||
category.articles.append(self.Article())
|
||||
assert not category.articles.any('name')
|
||||
|
||||
|
||||
class TestEscapeLike(TestCase):
|
||||
def test_escapes_wildcards(self):
|
||||
assert escape_like('_*%') == '*_***%'
|
||||
|
||||
|
||||
class TestSortQuery(TestCase):
|
||||
def test_without_sort_param_returns_the_query_object_untouched(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_column_ascending(self):
|
||||
query = sort_query(self.session.query(self.Article), 'name')
|
||||
assert 'ORDER BY article.name ASC' in str(query)
|
||||
|
||||
def test_sort_by_column_descending(self):
|
||||
query = sort_query(self.session.query(self.Article), '-name')
|
||||
assert 'ORDER BY article.name DESC' in str(query)
|
||||
|
||||
def test_skips_unknown_columns(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '-unknown')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_calculated_value_ascending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, 'articles')
|
||||
assert 'ORDER BY articles ASC' in str(query)
|
||||
|
||||
def test_sort_by_calculated_value_descending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, '-articles')
|
||||
assert 'ORDER BY articles DESC' in str(query)
|
||||
|
||||
def test_sort_by_joined_table_column(self):
|
||||
query = self.session.query(self.Article).join(self.Article.category)
|
||||
sorted_query = sort_query(query, 'category-name')
|
||||
assert 'category.name ASC' in str(sorted_query)
|
||||
|
||||
|
||||
class TestPhoneNumber(object):
|
||||
def setup_method(self, method):
|
||||
self.valid_phone_numbers = [
|
||||
'040 1234567',
|
||||
'+358 401234567',
|
||||
'09 2501234',
|
||||
'+358 92501234',
|
||||
'0800 939393',
|
||||
'09 4243 0456',
|
||||
'0600 900 500'
|
||||
]
|
||||
self.invalid_phone_numbers = [
|
||||
'abc',
|
||||
'+040 1234567',
|
||||
'0111234567',
|
||||
'358'
|
||||
]
|
||||
|
||||
def test_valid_phone_numbers(self):
|
||||
for raw_number in self.valid_phone_numbers:
|
||||
phone_number = PhoneNumber(raw_number, 'FI')
|
||||
assert phone_number.is_valid_number()
|
||||
|
||||
def test_invalid_phone_numbers(self):
|
||||
for raw_number in self.invalid_phone_numbers:
|
||||
try:
|
||||
phone_number = PhoneNumber(raw_number, 'FI')
|
||||
assert not phone_number.is_valid_number()
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_phone_number_attributes(self):
|
||||
phone_number = PhoneNumber('+358401234567')
|
||||
assert phone_number.e164 == u'+358401234567'
|
||||
assert phone_number.international == u'+358 40 1234567'
|
||||
assert phone_number.national == u'040 1234567'
|
||||
|
||||
|
||||
class TestPhoneNumberType(TestCase):
|
||||
def setup_method(self, method):
|
||||
super(TestPhoneNumberType, self).setup_method(method)
|
||||
self.phone_number = PhoneNumber(
|
||||
'040 1234567',
|
||||
'FI'
|
||||
)
|
||||
self.user = self.User()
|
||||
self.user.name = u'Someone'
|
||||
self.user.phone_number = self.phone_number
|
||||
self.session.add(self.user)
|
||||
self.session.commit()
|
||||
|
||||
def test_query_returns_phone_number_object(self):
|
||||
queried_user = self.session.query(self.User).first()
|
||||
assert queried_user.phone_number == self.phone_number
|
||||
|
||||
def test_phone_number_is_stored_as_string(self):
|
||||
result = self.session.execute(
|
||||
'SELECT phone_number FROM user WHERE id=:param',
|
||||
{'param': self.user.id}
|
||||
)
|
||||
assert result.first()[0] == u'+358401234567'
|
||||
|
||||
|
||||
class DatabaseTestCase(object):
|
||||
def create_models(self):
|
||||
pass
|
||||
|
||||
def setup_method(self, method):
|
||||
self.engine = create_engine(
|
||||
'sqlite:///'
|
||||
)
|
||||
#self.engine.echo = True
|
||||
self.Base = declarative_base()
|
||||
|
||||
self.create_models()
|
||||
self.Base.metadata.create_all(self.engine)
|
||||
Session = sessionmaker(bind=self.engine)
|
||||
self.session = Session()
|
||||
|
||||
def teardown_method(self, method):
|
||||
self.engine.dispose()
|
||||
self.Base.metadata.drop_all(self.engine)
|
||||
self.session.expunge_all()
|
||||
|
||||
|
||||
class TestMerge(DatabaseTestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.name
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.Unicode(255))
|
||||
content = sa.Column(sa.UnicodeText)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
self.User = User
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_updates_foreign_keys(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
post = self.BlogPost(title=u'Some title', author=john)
|
||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(post)
|
||||
self.session.add(post2)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert post.author == jack
|
||||
assert post2.author == jack
|
||||
|
||||
def test_deletes_from_entity(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john in self.session.deleted
|
||||
|
||||
|
||||
class TestMergeManyToManyAssociations(DatabaseTestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.name
|
||||
|
||||
team_member = sa.Table(
|
||||
'team_member', self.Base.metadata,
|
||||
sa.Column(
|
||||
'user_id', sa.Integer,
|
||||
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
'team_id', sa.Integer,
|
||||
sa.ForeignKey('team.id', ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
)
|
||||
|
||||
class Team(self.Base):
|
||||
__tablename__ = 'team'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
members = sa.orm.relationship(
|
||||
User,
|
||||
secondary=team_member,
|
||||
backref='teams'
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Team = Team
|
||||
|
||||
def test_when_association_only_exists_in_from_entity(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(john)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john not in team.members
|
||||
assert jack in team.members
|
||||
|
||||
def test_when_association_exists_in_both(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(john)
|
||||
team.members.append(jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john not in team.members
|
||||
assert jack in team.members
|
||||
count = self.session.execute(
|
||||
'SELECT COUNT(1) FROM team_member'
|
||||
).fetchone()[0]
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
|
||||
def create_models(self):
|
||||
class Team(self.Base):
|
||||
__tablename__ = 'team'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class TeamMember(self.Base):
|
||||
__tablename__ = 'team_member'
|
||||
user_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(User.id, ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
team_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(Team.id, ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
role = sa.Column(sa.Unicode(255))
|
||||
team = sa.orm.relationship(
|
||||
Team,
|
||||
backref=sa.orm.backref(
|
||||
'members',
|
||||
cascade='all, delete-orphan'
|
||||
),
|
||||
primaryjoin=team_id == Team.id,
|
||||
)
|
||||
user = sa.orm.relationship(
|
||||
User,
|
||||
backref=sa.orm.backref(
|
||||
'memberships',
|
||||
cascade='all, delete-orphan'
|
||||
),
|
||||
primaryjoin=user_id == User.id,
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.TeamMember = TeamMember
|
||||
self.Team = Team
|
||||
|
||||
def test_when_association_only_exists_in_from_entity(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(self.TeamMember(user=john))
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(team)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
self.session.commit()
|
||||
users = [member.user for member in team.members]
|
||||
assert john not in users
|
||||
assert jack in users
|
||||
|
||||
def test_when_association_exists_in_both(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(self.TeamMember(user=john))
|
||||
team.members.append(self.TeamMember(user=jack))
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(team)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
users = [member.user for member in team.members]
|
||||
assert john not in users
|
||||
assert jack in users
|
||||
assert self.session.query(self.TeamMember).count() == 1
|
86
tests/__init__.py
Normal file
86
tests/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import (
|
||||
escape_like,
|
||||
sort_query,
|
||||
InstrumentedList,
|
||||
PhoneNumber,
|
||||
PhoneNumberType,
|
||||
merge
|
||||
)
|
||||
|
||||
|
||||
class TestCase(object):
|
||||
def setup_method(self, method):
|
||||
self.engine = create_engine(
|
||||
'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
)
|
||||
self.Base = declarative_base()
|
||||
|
||||
self.create_models()
|
||||
self.Base.metadata.create_all(self.engine)
|
||||
|
||||
Session = sessionmaker(bind=self.engine)
|
||||
self.session = Session()
|
||||
|
||||
def teardown_method(self, method):
|
||||
self.session.close_all()
|
||||
self.Base.metadata.drop_all(self.engine)
|
||||
self.engine.dispose()
|
||||
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
phone_number = sa.Column(PhoneNumberType())
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
category = sa.orm.relationship(
|
||||
Category,
|
||||
primaryjoin=category_id == Category.id,
|
||||
backref=sa.orm.backref(
|
||||
'articles',
|
||||
collection_class=InstrumentedList
|
||||
)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Category = Category
|
||||
self.Article = Article
|
||||
|
||||
|
||||
class DatabaseTestCase(object):
|
||||
def create_models(self):
|
||||
pass
|
||||
|
||||
def setup_method(self, method):
|
||||
self.engine = create_engine(
|
||||
'sqlite:///'
|
||||
)
|
||||
#self.engine.echo = True
|
||||
self.Base = declarative_base()
|
||||
|
||||
self.create_models()
|
||||
self.Base.metadata.create_all(self.engine)
|
||||
Session = sessionmaker(bind=self.engine)
|
||||
self.session = Session()
|
||||
|
||||
def teardown_method(self, method):
|
||||
self.engine.dispose()
|
||||
self.Base.metadata.drop_all(self.engine)
|
||||
self.session.expunge_all()
|
14
tests/test_instrumented_list.py
Normal file
14
tests/test_instrumented_list.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestInstrumentedList(TestCase):
|
||||
def test_any_returns_true_if_member_has_attr_defined(self):
|
||||
category = self.Category()
|
||||
category.articles.append(self.Article())
|
||||
category.articles.append(self.Article(name=u'some name'))
|
||||
assert category.articles.any('name')
|
||||
|
||||
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
||||
category = self.Category()
|
||||
category.articles.append(self.Article())
|
||||
assert not category.articles.any('name')
|
196
tests/test_merge.py
Normal file
196
tests/test_merge.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import merge
|
||||
|
||||
from tests import DatabaseTestCase
|
||||
|
||||
|
||||
class TestMerge(DatabaseTestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.name
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.Unicode(255))
|
||||
content = sa.Column(sa.UnicodeText)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
self.User = User
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_updates_foreign_keys(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
post = self.BlogPost(title=u'Some title', author=john)
|
||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(post)
|
||||
self.session.add(post2)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert post.author == jack
|
||||
assert post2.author == jack
|
||||
|
||||
def test_deletes_from_entity(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john in self.session.deleted
|
||||
|
||||
|
||||
class TestMergeManyToManyAssociations(DatabaseTestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(%r)' % self.name
|
||||
|
||||
team_member = sa.Table(
|
||||
'team_member', self.Base.metadata,
|
||||
sa.Column(
|
||||
'user_id', sa.Integer,
|
||||
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
'team_id', sa.Integer,
|
||||
sa.ForeignKey('team.id', ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
)
|
||||
|
||||
class Team(self.Base):
|
||||
__tablename__ = 'team'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
members = sa.orm.relationship(
|
||||
User,
|
||||
secondary=team_member,
|
||||
backref='teams'
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Team = Team
|
||||
|
||||
def test_when_association_only_exists_in_from_entity(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(john)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john not in team.members
|
||||
assert jack in team.members
|
||||
|
||||
def test_when_association_exists_in_both(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(john)
|
||||
team.members.append(jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john not in team.members
|
||||
assert jack in team.members
|
||||
count = self.session.execute(
|
||||
'SELECT COUNT(1) FROM team_member'
|
||||
).fetchone()[0]
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
|
||||
def create_models(self):
|
||||
class Team(self.Base):
|
||||
__tablename__ = 'team'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class TeamMember(self.Base):
|
||||
__tablename__ = 'team_member'
|
||||
user_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(User.id, ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
team_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(Team.id, ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
role = sa.Column(sa.Unicode(255))
|
||||
team = sa.orm.relationship(
|
||||
Team,
|
||||
backref=sa.orm.backref(
|
||||
'members',
|
||||
cascade='all, delete-orphan'
|
||||
),
|
||||
primaryjoin=team_id == Team.id,
|
||||
)
|
||||
user = sa.orm.relationship(
|
||||
User,
|
||||
backref=sa.orm.backref(
|
||||
'memberships',
|
||||
cascade='all, delete-orphan'
|
||||
),
|
||||
primaryjoin=user_id == User.id,
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.TeamMember = TeamMember
|
||||
self.Team = Team
|
||||
|
||||
def test_when_association_only_exists_in_from_entity(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(self.TeamMember(user=john))
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(team)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
self.session.commit()
|
||||
users = [member.user for member in team.members]
|
||||
assert john not in users
|
||||
assert jack in users
|
||||
|
||||
def test_when_association_exists_in_both(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
team.members.append(self.TeamMember(user=john))
|
||||
team.members.append(self.TeamMember(user=jack))
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(team)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
users = [member.user for member in team.members]
|
||||
assert john not in users
|
||||
assert jack in users
|
||||
assert self.session.query(self.TeamMember).count() == 1
|
65
tests/test_phonenumber_type.py
Normal file
65
tests/test_phonenumber_type.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils import PhoneNumber
|
||||
|
||||
|
||||
class TestPhoneNumber(object):
|
||||
def setup_method(self, method):
|
||||
self.valid_phone_numbers = [
|
||||
'040 1234567',
|
||||
'+358 401234567',
|
||||
'09 2501234',
|
||||
'+358 92501234',
|
||||
'0800 939393',
|
||||
'09 4243 0456',
|
||||
'0600 900 500'
|
||||
]
|
||||
self.invalid_phone_numbers = [
|
||||
'abc',
|
||||
'+040 1234567',
|
||||
'0111234567',
|
||||
'358'
|
||||
]
|
||||
|
||||
def test_valid_phone_numbers(self):
|
||||
for raw_number in self.valid_phone_numbers:
|
||||
phone_number = PhoneNumber(raw_number, 'FI')
|
||||
assert phone_number.is_valid_number()
|
||||
|
||||
def test_invalid_phone_numbers(self):
|
||||
for raw_number in self.invalid_phone_numbers:
|
||||
try:
|
||||
phone_number = PhoneNumber(raw_number, 'FI')
|
||||
assert not phone_number.is_valid_number()
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_phone_number_attributes(self):
|
||||
phone_number = PhoneNumber('+358401234567')
|
||||
assert phone_number.e164 == u'+358401234567'
|
||||
assert phone_number.international == u'+358 40 1234567'
|
||||
assert phone_number.national == u'040 1234567'
|
||||
|
||||
|
||||
class TestPhoneNumberType(TestCase):
|
||||
def setup_method(self, method):
|
||||
super(TestPhoneNumberType, self).setup_method(method)
|
||||
self.phone_number = PhoneNumber(
|
||||
'040 1234567',
|
||||
'FI'
|
||||
)
|
||||
self.user = self.User()
|
||||
self.user.name = u'Someone'
|
||||
self.user.phone_number = self.phone_number
|
||||
self.session.add(self.user)
|
||||
self.session.commit()
|
||||
|
||||
def test_query_returns_phone_number_object(self):
|
||||
queried_user = self.session.query(self.User).first()
|
||||
assert queried_user.phone_number == self.phone_number
|
||||
|
||||
def test_phone_number_is_stored_as_string(self):
|
||||
result = self.session.execute(
|
||||
'SELECT phone_number FROM "user" WHERE id=:param',
|
||||
{'param': self.user.id}
|
||||
)
|
||||
assert result.first()[0] == u'+358401234567'
|
47
tests/test_utility_functions.py
Normal file
47
tests/test_utility_functions.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import escape_like, sort_query
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestEscapeLike(TestCase):
|
||||
def test_escapes_wildcards(self):
|
||||
assert escape_like('_*%') == '*_***%'
|
||||
|
||||
|
||||
class TestSortQuery(TestCase):
|
||||
def test_without_sort_param_returns_the_query_object_untouched(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_column_ascending(self):
|
||||
query = sort_query(self.session.query(self.Article), 'name')
|
||||
assert 'ORDER BY article.name ASC' in str(query)
|
||||
|
||||
def test_sort_by_column_descending(self):
|
||||
query = sort_query(self.session.query(self.Article), '-name')
|
||||
assert 'ORDER BY article.name DESC' in str(query)
|
||||
|
||||
def test_skips_unknown_columns(self):
|
||||
query = self.session.query(self.Article)
|
||||
sorted_query = sort_query(query, '-unknown')
|
||||
assert query == sorted_query
|
||||
|
||||
def test_sort_by_calculated_value_ascending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, 'articles')
|
||||
assert 'ORDER BY articles ASC' in str(query)
|
||||
|
||||
def test_sort_by_calculated_value_descending(self):
|
||||
query = self.session.query(
|
||||
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||
)
|
||||
query = sort_query(query, '-articles')
|
||||
assert 'ORDER BY articles DESC' in str(query)
|
||||
|
||||
def test_sort_by_joined_table_column(self):
|
||||
query = self.session.query(self.Article).join(self.Article.category)
|
||||
sorted_query = sort_query(query, 'category-name')
|
||||
assert 'category.name ASC' in str(sorted_query)
|
Reference in New Issue
Block a user