Added PhoneNumberType type decorator

This commit is contained in:
Vesa Uimonen
2013-03-20 13:26:15 +02:00
parent 37e1072787
commit 804061a3d2
2 changed files with 127 additions and 46 deletions

View File

@@ -1,3 +1,4 @@
import phonenumbers
from functools import wraps
from sqlalchemy.orm import defer
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
@@ -5,6 +6,37 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql.expression import desc, asc
from sqlalchemy import types
class PhoneNumberType(types.TypeDecorator):
"""
Changes PhoneNumber objects to a string representation on the way in and
changes them back to PhoneNumber objects on the way out. If E164 is used
as storing format, no country code is needed for parsing the database
value to PhoneNumber object.
"""
STORE_FORMAT = phonenumbers.PhoneNumberFormat.E164
impl = types.Unicode(20)
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
super(PhoneNumberType, self).__init__(*args, **kwargs)
self.country_code = country_code
self.impl = types.Unicode(max_length)
def process_bind_param(self, value, dialect):
return phonenumbers.format_number(
value,
self.STORE_FORMAT
)
def process_result_value(self, value, dialect):
if self.STORE_FORMAT == phonenumbers.PhoneNumberFormat.E164:
return phonenumbers.parse(value)
return phonenumbers.parse(
value,
self.country_code
)
class InstrumentedList(_InstrumentedList):

141
tests.py
View File

@@ -1,95 +1,144 @@
import phonenumbers
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import escape_like, sort_query, InstrumentedList
engine = create_engine(
'sqlite:///'
from sqlalchemy_utils import (
escape_like,
sort_query,
InstrumentedList,
PhoneNumberType
)
Base = declarative_base()
Session = sessionmaker(bind=engine)
session = Session()
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class TestCase(object):
def setup_method(self, method):
self.engine = create_engine('sqlite:///:memory:')
self.Base = declarative_base()
self.create_models()
self.Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def teardown_method(self, method):
self.session.close_all()
self.Base.metadata.drop_all(self.engine)
self.engine.dispose()
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
phone_number = sa.Column(PhoneNumberType())
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
self.User = User
self.Category = Category
self.Article = Article
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
collection_class=InstrumentedList
)
)
class TestInstrumentedList(object):
class TestInstrumentedList(TestCase):
def test_any_returns_true_if_member_has_attr_defined(self):
category = Category()
category.articles.append(Article())
category.articles.append(Article(name=u'some name'))
category = self.Category()
category.articles.append(self.Article())
category.articles.append(self.Article(name=u'some name'))
assert category.articles.any('name')
def test_any_returns_false_if_no_member_has_attr_defined(self):
category = Category()
category.articles.append(Article())
category = self.Category()
category.articles.append(self.Article())
assert not category.articles.any('name')
class TestEscapeLike(object):
class TestEscapeLike(TestCase):
def test_escapes_wildcards(self):
assert escape_like('_*%') == '*_***%'
class TestSortQuery(object):
class TestSortQuery(TestCase):
def test_without_sort_param_returns_the_query_object_untouched(self):
query = session.query(Article)
query = self.session.query(self.Article)
sorted_query = sort_query(query, '')
assert query == sorted_query
def test_sort_by_column_ascending(self):
query = sort_query(session.query(Article), 'name')
query = sort_query(self.session.query(self.Article), 'name')
assert 'ORDER BY article.name ASC' in str(query)
def test_sort_by_column_descending(self):
query = sort_query(session.query(Article), '-name')
query = sort_query(self.session.query(self.Article), '-name')
assert 'ORDER BY article.name DESC' in str(query)
def test_skips_unknown_columns(self):
query = session.query(Article)
query = self.session.query(self.Article)
sorted_query = sort_query(query, '-unknown')
assert query == sorted_query
def test_sort_by_calculated_value_ascending(self):
query = session.query(
Category, sa.func.count(Article.id).label('articles')
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, 'articles')
assert 'ORDER BY articles ASC' in str(query)
def test_sort_by_calculated_value_descending(self):
query = session.query(
Category, sa.func.count(Article.id).label('articles')
query = self.session.query(
self.Category, sa.func.count(self.Article.id).label('articles')
)
query = sort_query(query, '-articles')
assert 'ORDER BY articles DESC' in str(query)
def test_sort_by_joined_table_column(self):
query = session.query(Article).join(Article.category)
query = self.session.query(self.Article).join(self.Article.category)
sorted_query = sort_query(query, 'category-name')
assert 'category.name ASC' in str(sorted_query)
class TestPhoneNumberType(TestCase):
def setup_method(self, method):
super(TestPhoneNumberType, self).setup_method(method)
self.phone_number = phonenumbers.parse(
'040 1234567',
'FI'
)
self.user = self.User()
self.user.name = u'Someone'
self.user.phone_number = self.phone_number
self.session.add(self.user)
self.session.commit()
def test_query_returns_phone_number_object(self):
queried_user = self.session.query(self.User).first()
assert queried_user.phone_number == self.phone_number
def test_phone_number_is_stored_as_string(self):
result = self.session.execute(
'SELECT phone_number FROM user WHERE id=:param',
{'param': self.user.id}
)
assert result.first()[0] == u'+358401234567'