Added ArrowType, moved coercion_listener test setup to init
This commit is contained in:
1
setup.py
1
setup.py
@@ -55,6 +55,7 @@ setup(
|
||||
'flexmock>=0.9.7',
|
||||
'psycopg2>=2.4.6'
|
||||
],
|
||||
'arrow': ['arrow>=0.3.4'],
|
||||
'phone': ['phonenumbers3k==5.6b1'],
|
||||
'password': ['passlib >= 1.6, < 2.0'],
|
||||
'color': ['colour>=0.0.4']
|
||||
|
@@ -2,6 +2,7 @@ from functools import wraps
|
||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql.base import ischema_names
|
||||
from .arrow import ArrowType
|
||||
from .color import ColorType
|
||||
from .email import EmailType
|
||||
from .ip_address import IPAddressType
|
||||
@@ -18,6 +19,7 @@ from .uuid import UUIDType
|
||||
|
||||
|
||||
__all__ = (
|
||||
ArrowType,
|
||||
ColorType,
|
||||
EmailType,
|
||||
IPAddressType,
|
||||
|
45
sqlalchemy_utils/types/arrow.py
Normal file
45
sqlalchemy_utils/types/arrow.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import absolute_import
|
||||
from collections import Iterable
|
||||
from datetime import datetime
|
||||
import six
|
||||
|
||||
arrow = None
|
||||
try:
|
||||
import arrow
|
||||
except:
|
||||
pass
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy_utils import ImproperlyConfigured
|
||||
|
||||
|
||||
class ArrowType(types.TypeDecorator):
|
||||
impl = types.DateTime
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if not arrow:
|
||||
raise ImproperlyConfigured(
|
||||
"'arrow' package is required to use 'ArrowType'"
|
||||
)
|
||||
|
||||
super(ArrowType, self).__init__(*args, **kwargs)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
return value.datetime
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value:
|
||||
return arrow.get(value)
|
||||
return value
|
||||
|
||||
def coercion_listener(self, target, value, oldvalue, initiator):
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, six.string_types):
|
||||
value = arrow.get(value)
|
||||
elif isinstance(value, Iterable):
|
||||
value = arrow.get(*value)
|
||||
elif isinstance(value, datetime):
|
||||
value = arrow.get(value)
|
||||
return value
|
@@ -6,6 +6,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import InstrumentedList
|
||||
from sqlalchemy_utils import coercion_listener
|
||||
|
||||
|
||||
@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
|
||||
@@ -19,6 +20,9 @@ def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
|
||||
warnings.simplefilter('error', sa.exc.SAWarning)
|
||||
|
||||
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
|
||||
class TestCase(object):
|
||||
dns = 'sqlite:///:memory:'
|
||||
|
||||
|
33
tests/test_arrow.py
Normal file
33
tests/test_arrow.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from pytest import mark
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.types import arrow
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
@mark.xfail('arrow.arrow is None')
|
||||
class TestArrowDateTimeType(TestCase):
|
||||
def create_models(self):
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
created_at = sa.Column(arrow.ArrowType)
|
||||
|
||||
self.Article = Article
|
||||
|
||||
def test_parameter_processing(self):
|
||||
article = self.Article(
|
||||
created_at=arrow.arrow.get(datetime(2000, 11, 1))
|
||||
)
|
||||
|
||||
self.session.add(article)
|
||||
self.session.commit()
|
||||
|
||||
article = self.session.query(self.Article).first()
|
||||
assert article.created_at.datetime
|
||||
|
||||
def test_string_coercion(self):
|
||||
article = self.Article(
|
||||
created_at='1367900664'
|
||||
)
|
||||
assert article.created_at.year == 2013
|
@@ -1,6 +1,6 @@
|
||||
from pytest import mark
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import ColorType, coercion_listener
|
||||
from sqlalchemy_utils import ColorType
|
||||
from sqlalchemy_utils.types import color
|
||||
from tests import TestCase
|
||||
|
||||
@@ -17,8 +17,6 @@ class TestColorType(TestCase):
|
||||
return 'Document(%r)' % self.id
|
||||
|
||||
self.Document = Document
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
|
||||
def test_color_parameter_processing(self):
|
||||
from colour import Color
|
||||
|
@@ -2,7 +2,7 @@ from pytest import mark
|
||||
import sqlalchemy as sa
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils.types import password
|
||||
from sqlalchemy_utils import Password, PasswordType, coercion_listener
|
||||
from sqlalchemy_utils import Password, PasswordType
|
||||
|
||||
|
||||
@mark.xfail('password.passlib is None')
|
||||
@@ -25,7 +25,6 @@ class TestPasswordType(TestCase):
|
||||
return 'User(%r)' % self.id
|
||||
|
||||
self.User = User
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
def test_encrypt(self):
|
||||
"""Should encrypt the password on setting the attribute."""
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from pytest import mark
|
||||
from tests import TestCase
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import PhoneNumberType, PhoneNumber, coercion_listener
|
||||
from sqlalchemy_utils import PhoneNumberType, PhoneNumber
|
||||
from sqlalchemy_utils.types import phone_number
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ class TestPhoneNumberType(TestCase):
|
||||
phone_number = sa.Column(PhoneNumberType())
|
||||
|
||||
self.User = User
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
def setup_method(self, method):
|
||||
super(TestPhoneNumberType, self).setup_method(method)
|
||||
|
@@ -1,11 +1,10 @@
|
||||
import sqlalchemy as sa
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils import UUIDType, coercion_listener
|
||||
from sqlalchemy_utils import UUIDType
|
||||
import uuid
|
||||
|
||||
|
||||
class TestUUIDType(TestCase):
|
||||
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
@@ -15,7 +14,6 @@ class TestUUIDType(TestCase):
|
||||
return 'User(%r)' % self.id
|
||||
|
||||
self.User = User
|
||||
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)
|
||||
|
||||
def test_commit(self):
|
||||
obj = self.User()
|
||||
|
Reference in New Issue
Block a user