From 54dff643daaf38223ddcaed94f8c2ca00f155cb4 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 25 Jul 2013 14:12:52 +0300 Subject: [PATCH] Added ArrowType, moved coercion_listener test setup to init --- setup.py | 1 + sqlalchemy_utils/types/__init__.py | 2 ++ sqlalchemy_utils/types/arrow.py | 45 ++++++++++++++++++++++++++++++ tests/__init__.py | 4 +++ tests/test_arrow.py | 33 ++++++++++++++++++++++ tests/test_color.py | 4 +-- tests/test_password.py | 3 +- tests/test_phonenumber_type.py | 3 +- tests/test_uuid.py | 4 +-- 9 files changed, 89 insertions(+), 10 deletions(-) create mode 100644 sqlalchemy_utils/types/arrow.py create mode 100644 tests/test_arrow.py diff --git a/setup.py b/setup.py index dfa4da2..5f55666 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ setup( 'flexmock>=0.9.7', 'psycopg2>=2.4.6' ], + 'arrow': ['arrow>=0.3.4'], 'phone': ['phonenumbers3k==5.6b1'], 'password': ['passlib >= 1.6, < 2.0'], 'color': ['colour>=0.0.4'] diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index c8c60ae..b91867a 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -2,6 +2,7 @@ from functools import wraps from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from sqlalchemy import types from sqlalchemy.dialects.postgresql.base import ischema_names +from .arrow import ArrowType from .color import ColorType from .email import EmailType from .ip_address import IPAddressType @@ -18,6 +19,7 @@ from .uuid import UUIDType __all__ = ( + ArrowType, ColorType, EmailType, IPAddressType, diff --git a/sqlalchemy_utils/types/arrow.py b/sqlalchemy_utils/types/arrow.py new file mode 100644 index 0000000..4d669e5 --- /dev/null +++ b/sqlalchemy_utils/types/arrow.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from collections import Iterable +from datetime import datetime +import six + +arrow = None +try: + import arrow +except: + pass +from sqlalchemy import types +from sqlalchemy_utils import ImproperlyConfigured + + +class ArrowType(types.TypeDecorator): + impl = types.DateTime + + def __init__(self, *args, **kwargs): + if not arrow: + raise ImproperlyConfigured( + "'arrow' package is required to use 'ArrowType'" + ) + + super(ArrowType, self).__init__(*args, **kwargs) + + def process_bind_param(self, value, dialect): + if value: + return value.datetime + return value + + def process_result_value(self, value, dialect): + if value: + return arrow.get(value) + return value + + def coercion_listener(self, target, value, oldvalue, initiator): + if value is None: + return None + elif isinstance(value, six.string_types): + value = arrow.get(value) + elif isinstance(value, Iterable): + value = arrow.get(*value) + elif isinstance(value, datetime): + value = arrow.get(value) + return value diff --git a/tests/__init__.py b/tests/__init__.py index 12c3d2c..e7c1894 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import InstrumentedList +from sqlalchemy_utils import coercion_listener @sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute') @@ -19,6 +20,9 @@ def count_sql_calls(conn, cursor, statement, parameters, context, executemany): warnings.simplefilter('error', sa.exc.SAWarning) +sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) + + class TestCase(object): dns = 'sqlite:///:memory:' diff --git a/tests/test_arrow.py b/tests/test_arrow.py new file mode 100644 index 0000000..66f8c1c --- /dev/null +++ b/tests/test_arrow.py @@ -0,0 +1,33 @@ +from datetime import datetime +from pytest import mark +import sqlalchemy as sa +from sqlalchemy_utils.types import arrow +from tests import TestCase + + +@mark.xfail('arrow.arrow is None') +class TestArrowDateTimeType(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + created_at = sa.Column(arrow.ArrowType) + + self.Article = Article + + def test_parameter_processing(self): + article = self.Article( + created_at=arrow.arrow.get(datetime(2000, 11, 1)) + ) + + self.session.add(article) + self.session.commit() + + article = self.session.query(self.Article).first() + assert article.created_at.datetime + + def test_string_coercion(self): + article = self.Article( + created_at='1367900664' + ) + assert article.created_at.year == 2013 diff --git a/tests/test_color.py b/tests/test_color.py index de82cb2..617efdf 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -1,6 +1,6 @@ from pytest import mark import sqlalchemy as sa -from sqlalchemy_utils import ColorType, coercion_listener +from sqlalchemy_utils import ColorType from sqlalchemy_utils.types import color from tests import TestCase @@ -17,8 +17,6 @@ class TestColorType(TestCase): return 'Document(%r)' % self.id self.Document = Document - sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) - def test_color_parameter_processing(self): from colour import Color diff --git a/tests/test_password.py b/tests/test_password.py index 1786109..5537d72 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -2,7 +2,7 @@ from pytest import mark import sqlalchemy as sa from tests import TestCase from sqlalchemy_utils.types import password -from sqlalchemy_utils import Password, PasswordType, coercion_listener +from sqlalchemy_utils import Password, PasswordType @mark.xfail('password.passlib is None') @@ -25,7 +25,6 @@ class TestPasswordType(TestCase): return 'User(%r)' % self.id self.User = User - sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) def test_encrypt(self): """Should encrypt the password on setting the attribute.""" diff --git a/tests/test_phonenumber_type.py b/tests/test_phonenumber_type.py index 661182a..744a1f5 100644 --- a/tests/test_phonenumber_type.py +++ b/tests/test_phonenumber_type.py @@ -1,7 +1,7 @@ from pytest import mark from tests import TestCase import sqlalchemy as sa -from sqlalchemy_utils import PhoneNumberType, PhoneNumber, coercion_listener +from sqlalchemy_utils import PhoneNumberType, PhoneNumber from sqlalchemy_utils.types import phone_number @@ -60,7 +60,6 @@ class TestPhoneNumberType(TestCase): phone_number = sa.Column(PhoneNumberType()) self.User = User - sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) def setup_method(self, method): super(TestPhoneNumberType, self).setup_method(method) diff --git a/tests/test_uuid.py b/tests/test_uuid.py index e07465e..5a1de10 100644 --- a/tests/test_uuid.py +++ b/tests/test_uuid.py @@ -1,11 +1,10 @@ import sqlalchemy as sa from tests import TestCase -from sqlalchemy_utils import UUIDType, coercion_listener +from sqlalchemy_utils import UUIDType import uuid class TestUUIDType(TestCase): - def create_models(self): class User(self.Base): __tablename__ = 'user' @@ -15,7 +14,6 @@ class TestUUIDType(TestCase): return 'User(%r)' % self.id self.User = User - sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) def test_commit(self): obj = self.User()