Added PhoneNumberType type decorator
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import phonenumbers
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from sqlalchemy.orm import defer
|
from sqlalchemy.orm import defer
|
||||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||||
@@ -5,6 +6,37 @@ from sqlalchemy.orm.mapper import Mapper
|
|||||||
from sqlalchemy.orm.query import _ColumnEntity
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
from sqlalchemy.orm.properties import ColumnProperty
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
from sqlalchemy.sql.expression import desc, asc
|
from sqlalchemy.sql.expression import desc, asc
|
||||||
|
from sqlalchemy import types
|
||||||
|
|
||||||
|
|
||||||
|
class PhoneNumberType(types.TypeDecorator):
|
||||||
|
"""
|
||||||
|
Changes PhoneNumber objects to a string representation on the way in and
|
||||||
|
changes them back to PhoneNumber objects on the way out. If E164 is used
|
||||||
|
as storing format, no country code is needed for parsing the database
|
||||||
|
value to PhoneNumber object.
|
||||||
|
"""
|
||||||
|
STORE_FORMAT = phonenumbers.PhoneNumberFormat.E164
|
||||||
|
impl = types.Unicode(20)
|
||||||
|
|
||||||
|
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
|
||||||
|
super(PhoneNumberType, self).__init__(*args, **kwargs)
|
||||||
|
self.country_code = country_code
|
||||||
|
self.impl = types.Unicode(max_length)
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
return phonenumbers.format_number(
|
||||||
|
value,
|
||||||
|
self.STORE_FORMAT
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
if self.STORE_FORMAT == phonenumbers.PhoneNumberFormat.E164:
|
||||||
|
return phonenumbers.parse(value)
|
||||||
|
return phonenumbers.parse(
|
||||||
|
value,
|
||||||
|
self.country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InstrumentedList(_InstrumentedList):
|
class InstrumentedList(_InstrumentedList):
|
||||||
|
107
tests.py
107
tests.py
@@ -1,28 +1,48 @@
|
|||||||
|
import phonenumbers
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
from sqlalchemy_utils import escape_like, sort_query, InstrumentedList
|
from sqlalchemy_utils import (
|
||||||
|
escape_like,
|
||||||
|
sort_query,
|
||||||
engine = create_engine(
|
InstrumentedList,
|
||||||
'sqlite:///'
|
PhoneNumberType
|
||||||
)
|
)
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
Session = sessionmaker(bind=engine)
|
|
||||||
session = Session()
|
|
||||||
|
|
||||||
|
|
||||||
class Category(Base):
|
class TestCase(object):
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.engine = create_engine('sqlite:///:memory:')
|
||||||
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
self.create_models()
|
||||||
|
self.Base.metadata.create_all(self.engine)
|
||||||
|
|
||||||
|
Session = sessionmaker(bind=self.engine)
|
||||||
|
self.session = Session()
|
||||||
|
|
||||||
|
def teardown_method(self, method):
|
||||||
|
self.session.close_all()
|
||||||
|
self.Base.metadata.drop_all(self.engine)
|
||||||
|
self.engine.dispose()
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
phone_number = sa.Column(PhoneNumberType())
|
||||||
|
|
||||||
|
class Category(self.Base):
|
||||||
__tablename__ = 'category'
|
__tablename__ = 'category'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class Article(self.Base):
|
||||||
class Article(Base):
|
|
||||||
__tablename__ = 'article'
|
__tablename__ = 'article'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
@@ -37,59 +57,88 @@ class Article(Base):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Category = Category
|
||||||
|
self.Article = Article
|
||||||
|
|
||||||
class TestInstrumentedList(object):
|
|
||||||
|
class TestInstrumentedList(TestCase):
|
||||||
def test_any_returns_true_if_member_has_attr_defined(self):
|
def test_any_returns_true_if_member_has_attr_defined(self):
|
||||||
category = Category()
|
category = self.Category()
|
||||||
category.articles.append(Article())
|
category.articles.append(self.Article())
|
||||||
category.articles.append(Article(name=u'some name'))
|
category.articles.append(self.Article(name=u'some name'))
|
||||||
assert category.articles.any('name')
|
assert category.articles.any('name')
|
||||||
|
|
||||||
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
def test_any_returns_false_if_no_member_has_attr_defined(self):
|
||||||
category = Category()
|
category = self.Category()
|
||||||
category.articles.append(Article())
|
category.articles.append(self.Article())
|
||||||
assert not category.articles.any('name')
|
assert not category.articles.any('name')
|
||||||
|
|
||||||
|
|
||||||
class TestEscapeLike(object):
|
class TestEscapeLike(TestCase):
|
||||||
def test_escapes_wildcards(self):
|
def test_escapes_wildcards(self):
|
||||||
assert escape_like('_*%') == '*_***%'
|
assert escape_like('_*%') == '*_***%'
|
||||||
|
|
||||||
|
|
||||||
class TestSortQuery(object):
|
class TestSortQuery(TestCase):
|
||||||
def test_without_sort_param_returns_the_query_object_untouched(self):
|
def test_without_sort_param_returns_the_query_object_untouched(self):
|
||||||
query = session.query(Article)
|
query = self.session.query(self.Article)
|
||||||
sorted_query = sort_query(query, '')
|
sorted_query = sort_query(query, '')
|
||||||
assert query == sorted_query
|
assert query == sorted_query
|
||||||
|
|
||||||
def test_sort_by_column_ascending(self):
|
def test_sort_by_column_ascending(self):
|
||||||
query = sort_query(session.query(Article), 'name')
|
query = sort_query(self.session.query(self.Article), 'name')
|
||||||
assert 'ORDER BY article.name ASC' in str(query)
|
assert 'ORDER BY article.name ASC' in str(query)
|
||||||
|
|
||||||
def test_sort_by_column_descending(self):
|
def test_sort_by_column_descending(self):
|
||||||
query = sort_query(session.query(Article), '-name')
|
query = sort_query(self.session.query(self.Article), '-name')
|
||||||
assert 'ORDER BY article.name DESC' in str(query)
|
assert 'ORDER BY article.name DESC' in str(query)
|
||||||
|
|
||||||
def test_skips_unknown_columns(self):
|
def test_skips_unknown_columns(self):
|
||||||
query = session.query(Article)
|
query = self.session.query(self.Article)
|
||||||
sorted_query = sort_query(query, '-unknown')
|
sorted_query = sort_query(query, '-unknown')
|
||||||
assert query == sorted_query
|
assert query == sorted_query
|
||||||
|
|
||||||
def test_sort_by_calculated_value_ascending(self):
|
def test_sort_by_calculated_value_ascending(self):
|
||||||
query = session.query(
|
query = self.session.query(
|
||||||
Category, sa.func.count(Article.id).label('articles')
|
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||||
)
|
)
|
||||||
query = sort_query(query, 'articles')
|
query = sort_query(query, 'articles')
|
||||||
assert 'ORDER BY articles ASC' in str(query)
|
assert 'ORDER BY articles ASC' in str(query)
|
||||||
|
|
||||||
def test_sort_by_calculated_value_descending(self):
|
def test_sort_by_calculated_value_descending(self):
|
||||||
query = session.query(
|
query = self.session.query(
|
||||||
Category, sa.func.count(Article.id).label('articles')
|
self.Category, sa.func.count(self.Article.id).label('articles')
|
||||||
)
|
)
|
||||||
query = sort_query(query, '-articles')
|
query = sort_query(query, '-articles')
|
||||||
assert 'ORDER BY articles DESC' in str(query)
|
assert 'ORDER BY articles DESC' in str(query)
|
||||||
|
|
||||||
def test_sort_by_joined_table_column(self):
|
def test_sort_by_joined_table_column(self):
|
||||||
query = session.query(Article).join(Article.category)
|
query = self.session.query(self.Article).join(self.Article.category)
|
||||||
sorted_query = sort_query(query, 'category-name')
|
sorted_query = sort_query(query, 'category-name')
|
||||||
assert 'category.name ASC' in str(sorted_query)
|
assert 'category.name ASC' in str(sorted_query)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhoneNumberType(TestCase):
|
||||||
|
def setup_method(self, method):
|
||||||
|
super(TestPhoneNumberType, self).setup_method(method)
|
||||||
|
self.phone_number = phonenumbers.parse(
|
||||||
|
'040 1234567',
|
||||||
|
'FI'
|
||||||
|
)
|
||||||
|
self.user = self.User()
|
||||||
|
self.user.name = u'Someone'
|
||||||
|
self.user.phone_number = self.phone_number
|
||||||
|
self.session.add(self.user)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def test_query_returns_phone_number_object(self):
|
||||||
|
queried_user = self.session.query(self.User).first()
|
||||||
|
assert queried_user.phone_number == self.phone_number
|
||||||
|
|
||||||
|
def test_phone_number_is_stored_as_string(self):
|
||||||
|
result = self.session.execute(
|
||||||
|
'SELECT phone_number FROM user WHERE id=:param',
|
||||||
|
{'param': self.user.id}
|
||||||
|
)
|
||||||
|
assert result.first()[0] == u'+358401234567'
|
||||||
|
Reference in New Issue
Block a user