From 804061a3d262528d8f6ec41a30275e7b7dffdf7d Mon Sep 17 00:00:00 2001 From: Vesa Uimonen Date: Wed, 20 Mar 2013 13:26:15 +0200 Subject: [PATCH] Added PhoneNumberType type decorator --- sqlalchemy_utils/__init__.py | 32 ++++++++ tests.py | 141 +++++++++++++++++++++++------------ 2 files changed, 127 insertions(+), 46 deletions(-) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index b631e69..2028ab0 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,3 +1,4 @@ +import phonenumbers from functools import wraps from sqlalchemy.orm import defer from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList @@ -5,6 +6,37 @@ from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.sql.expression import desc, asc +from sqlalchemy import types + + +class PhoneNumberType(types.TypeDecorator): + """ + Changes PhoneNumber objects to a string representation on the way in and + changes them back to PhoneNumber objects on the way out. If E164 is used + as storing format, no country code is needed for parsing the database + value to PhoneNumber object. + """ + STORE_FORMAT = phonenumbers.PhoneNumberFormat.E164 + impl = types.Unicode(20) + + def __init__(self, country_code='US', max_length=20, *args, **kwargs): + super(PhoneNumberType, self).__init__(*args, **kwargs) + self.country_code = country_code + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + return phonenumbers.format_number( + value, + self.STORE_FORMAT + ) + + def process_result_value(self, value, dialect): + if self.STORE_FORMAT == phonenumbers.PhoneNumberFormat.E164: + return phonenumbers.parse(value) + return phonenumbers.parse( + value, + self.country_code + ) class InstrumentedList(_InstrumentedList): diff --git a/tests.py b/tests.py index ee775fb..d0cbfc6 100644 --- a/tests.py +++ b/tests.py @@ -1,95 +1,144 @@ +import phonenumbers import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_utils import escape_like, sort_query, InstrumentedList - - -engine = create_engine( - 'sqlite:///' +from sqlalchemy_utils import ( + escape_like, + sort_query, + InstrumentedList, + PhoneNumberType ) -Base = declarative_base() - -Session = sessionmaker(bind=engine) -session = Session() -class Category(Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) +class TestCase(object): + + def setup_method(self, method): + self.engine = create_engine('sqlite:///:memory:') + self.Base = declarative_base() + + self.create_models() + self.Base.metadata.create_all(self.engine) + + Session = sessionmaker(bind=self.engine) + self.session = Session() + + def teardown_method(self, method): + self.session.close_all() + self.Base.metadata.drop_all(self.engine) + self.engine.dispose() + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + phone_number = sa.Column(PhoneNumberType()) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles', + collection_class=InstrumentedList + ) + ) + + self.User = User + self.Category = Category + self.Article = Article -class Article(Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) - - category = sa.orm.relationship( - Category, - primaryjoin=category_id == Category.id, - backref=sa.orm.backref( - 'articles', - collection_class=InstrumentedList - ) - ) - - -class TestInstrumentedList(object): +class TestInstrumentedList(TestCase): def test_any_returns_true_if_member_has_attr_defined(self): - category = Category() - category.articles.append(Article()) - category.articles.append(Article(name=u'some name')) + category = self.Category() + category.articles.append(self.Article()) + category.articles.append(self.Article(name=u'some name')) assert category.articles.any('name') def test_any_returns_false_if_no_member_has_attr_defined(self): - category = Category() - category.articles.append(Article()) + category = self.Category() + category.articles.append(self.Article()) assert not category.articles.any('name') -class TestEscapeLike(object): +class TestEscapeLike(TestCase): def test_escapes_wildcards(self): assert escape_like('_*%') == '*_***%' -class TestSortQuery(object): +class TestSortQuery(TestCase): def test_without_sort_param_returns_the_query_object_untouched(self): - query = session.query(Article) + query = self.session.query(self.Article) sorted_query = sort_query(query, '') assert query == sorted_query def test_sort_by_column_ascending(self): - query = sort_query(session.query(Article), 'name') + query = sort_query(self.session.query(self.Article), 'name') assert 'ORDER BY article.name ASC' in str(query) def test_sort_by_column_descending(self): - query = sort_query(session.query(Article), '-name') + query = sort_query(self.session.query(self.Article), '-name') assert 'ORDER BY article.name DESC' in str(query) def test_skips_unknown_columns(self): - query = session.query(Article) + query = self.session.query(self.Article) sorted_query = sort_query(query, '-unknown') assert query == sorted_query def test_sort_by_calculated_value_ascending(self): - query = session.query( - Category, sa.func.count(Article.id).label('articles') + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') ) query = sort_query(query, 'articles') assert 'ORDER BY articles ASC' in str(query) def test_sort_by_calculated_value_descending(self): - query = session.query( - Category, sa.func.count(Article.id).label('articles') + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') ) query = sort_query(query, '-articles') assert 'ORDER BY articles DESC' in str(query) def test_sort_by_joined_table_column(self): - query = session.query(Article).join(Article.category) + query = self.session.query(self.Article).join(self.Article.category) sorted_query = sort_query(query, 'category-name') assert 'category.name ASC' in str(sorted_query) + + +class TestPhoneNumberType(TestCase): + def setup_method(self, method): + super(TestPhoneNumberType, self).setup_method(method) + self.phone_number = phonenumbers.parse( + '040 1234567', + 'FI' + ) + self.user = self.User() + self.user.name = u'Someone' + self.user.phone_number = self.phone_number + self.session.add(self.user) + self.session.commit() + + def test_query_returns_phone_number_object(self): + queried_user = self.session.query(self.User).first() + assert queried_user.phone_number == self.phone_number + + def test_phone_number_is_stored_as_string(self): + result = self.session.execute( + 'SELECT phone_number FROM user WHERE id=:param', + {'param': self.user.id} + ) + assert result.first()[0] == u'+358401234567'