Added PhoneNumberType type decorator

This commit is contained in:
Vesa Uimonen
2013-03-20 13:26:15 +02:00
parent 37e1072787
commit 804061a3d2
2 changed files with 127 additions and 46 deletions

View File

@@ -1,3 +1,4 @@
import phonenumbers
from functools import wraps from functools import wraps
from sqlalchemy.orm import defer from sqlalchemy.orm import defer
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
@@ -5,6 +6,37 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql.expression import desc, asc from sqlalchemy.sql.expression import desc, asc
from sqlalchemy import types
class PhoneNumberType(types.TypeDecorator):
"""
Changes PhoneNumber objects to a string representation on the way in and
changes them back to PhoneNumber objects on the way out. If E164 is used
as storing format, no country code is needed for parsing the database
value to PhoneNumber object.
"""
STORE_FORMAT = phonenumbers.PhoneNumberFormat.E164
impl = types.Unicode(20)
def __init__(self, country_code='US', max_length=20, *args, **kwargs):
super(PhoneNumberType, self).__init__(*args, **kwargs)
self.country_code = country_code
self.impl = types.Unicode(max_length)
def process_bind_param(self, value, dialect):
return phonenumbers.format_number(
value,
self.STORE_FORMAT
)
def process_result_value(self, value, dialect):
if self.STORE_FORMAT == phonenumbers.PhoneNumberFormat.E164:
return phonenumbers.parse(value)
return phonenumbers.parse(
value,
self.country_code
)
class InstrumentedList(_InstrumentedList): class InstrumentedList(_InstrumentedList):

107
tests.py
View File

@@ -1,28 +1,48 @@
import phonenumbers
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import escape_like, sort_query, InstrumentedList from sqlalchemy_utils import (
escape_like,
sort_query,
engine = create_engine( InstrumentedList,
'sqlite:///' PhoneNumberType
) )
Base = declarative_base()
Session = sessionmaker(bind=engine)
session = Session()
class Category(Base): class TestCase(object):
def setup_method(self, method):
self.engine = create_engine('sqlite:///:memory:')
self.Base = declarative_base()
self.create_models()
self.Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def teardown_method(self, method):
self.session.close_all()
self.Base.metadata.drop_all(self.engine)
self.engine.dispose()
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
phone_number = sa.Column(PhoneNumberType())
class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
class Article(self.Base):
class Article(Base):
__tablename__ = 'article' __tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
@@ -37,59 +57,88 @@ class Article(Base):
) )
) )
self.User = User
self.Category = Category
self.Article = Article
class TestInstrumentedList(object):
class TestInstrumentedList(TestCase):
def test_any_returns_true_if_member_has_attr_defined(self): def test_any_returns_true_if_member_has_attr_defined(self):
category = Category() category = self.Category()
category.articles.append(Article()) category.articles.append(self.Article())
category.articles.append(Article(name=u'some name')) category.articles.append(self.Article(name=u'some name'))
assert category.articles.any('name') assert category.articles.any('name')
def test_any_returns_false_if_no_member_has_attr_defined(self): def test_any_returns_false_if_no_member_has_attr_defined(self):
category = Category() category = self.Category()
category.articles.append(Article()) category.articles.append(self.Article())
assert not category.articles.any('name') assert not category.articles.any('name')
class TestEscapeLike(object): class TestEscapeLike(TestCase):
def test_escapes_wildcards(self): def test_escapes_wildcards(self):
assert escape_like('_*%') == '*_***%' assert escape_like('_*%') == '*_***%'
class TestSortQuery(object): class TestSortQuery(TestCase):
def test_without_sort_param_returns_the_query_object_untouched(self): def test_without_sort_param_returns_the_query_object_untouched(self):
query = session.query(Article) query = self.session.query(self.Article)
sorted_query = sort_query(query, '') sorted_query = sort_query(query, '')
assert query == sorted_query assert query == sorted_query
def test_sort_by_column_ascending(self): def test_sort_by_column_ascending(self):
query = sort_query(session.query(Article), 'name') query = sort_query(self.session.query(self.Article), 'name')
assert 'ORDER BY article.name ASC' in str(query) assert 'ORDER BY article.name ASC' in str(query)
def test_sort_by_column_descending(self): def test_sort_by_column_descending(self):
query = sort_query(session.query(Article), '-name') query = sort_query(self.session.query(self.Article), '-name')
assert 'ORDER BY article.name DESC' in str(query) assert 'ORDER BY article.name DESC' in str(query)
def test_skips_unknown_columns(self): def test_skips_unknown_columns(self):
query = session.query(Article) query = self.session.query(self.Article)
sorted_query = sort_query(query, '-unknown') sorted_query = sort_query(query, '-unknown')
assert query == sorted_query assert query == sorted_query
def test_sort_by_calculated_value_ascending(self): def test_sort_by_calculated_value_ascending(self):
query = session.query( query = self.session.query(
Category, sa.func.count(Article.id).label('articles') self.Category, sa.func.count(self.Article.id).label('articles')
) )
query = sort_query(query, 'articles') query = sort_query(query, 'articles')
assert 'ORDER BY articles ASC' in str(query) assert 'ORDER BY articles ASC' in str(query)
def test_sort_by_calculated_value_descending(self): def test_sort_by_calculated_value_descending(self):
query = session.query( query = self.session.query(
Category, sa.func.count(Article.id).label('articles') self.Category, sa.func.count(self.Article.id).label('articles')
) )
query = sort_query(query, '-articles') query = sort_query(query, '-articles')
assert 'ORDER BY articles DESC' in str(query) assert 'ORDER BY articles DESC' in str(query)
def test_sort_by_joined_table_column(self): def test_sort_by_joined_table_column(self):
query = session.query(Article).join(Article.category) query = self.session.query(self.Article).join(self.Article.category)
sorted_query = sort_query(query, 'category-name') sorted_query = sort_query(query, 'category-name')
assert 'category.name ASC' in str(sorted_query) assert 'category.name ASC' in str(sorted_query)
class TestPhoneNumberType(TestCase):
def setup_method(self, method):
super(TestPhoneNumberType, self).setup_method(method)
self.phone_number = phonenumbers.parse(
'040 1234567',
'FI'
)
self.user = self.User()
self.user.name = u'Someone'
self.user.phone_number = self.phone_number
self.session.add(self.user)
self.session.commit()
def test_query_returns_phone_number_object(self):
queried_user = self.session.query(self.User).first()
assert queried_user.phone_number == self.phone_number
def test_phone_number_is_stored_as_string(self):
result = self.session.execute(
'SELECT phone_number FROM user WHERE id=:param',
{'param': self.user.id}
)
assert result.first()[0] == u'+358401234567'