diff --git a/sqlalchemy_utils/types/arrow.py b/sqlalchemy_utils/types/arrow.py index 74b4fe7..ec8677c 100644 --- a/sqlalchemy_utils/types/arrow.py +++ b/sqlalchemy_utils/types/arrow.py @@ -66,7 +66,8 @@ class ArrowType(types.TypeDecorator, ScalarCoercible): def process_bind_param(self, value, dialect): if value: - return self._coerce(value).to('UTC').naive + utc_val = self._coerce(value).to('UTC') + return utc_val.datetime if self.impl.timezone else utc_val.naive return value def process_result_value(self, value, dialect): diff --git a/tests/types/test_arrow.py b/tests/types/test_arrow.py index dd33f33..ec4bb23 100644 --- a/tests/types/test_arrow.py +++ b/tests/types/test_arrow.py @@ -2,6 +2,7 @@ from datetime import datetime import pytest import sqlalchemy as sa +from dateutil import tz from sqlalchemy_utils.types import arrow @@ -12,6 +13,8 @@ def Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) created_at = sa.Column(arrow.ArrowType) + published_at = sa.Column(arrow.ArrowType(timezone=True)) + published_at_dt = sa.Column(sa.DateTime(timezone=True)) return Article @@ -61,3 +64,17 @@ class TestArrowDateTimeType(object): clause = Article.created_at > '2015-01-01' compiled = str(clause.compile(compile_kwargs={"literal_binds": True})) assert compiled == 'article.created_at > 2015-01-01' + + @pytest.mark.usefixtures('postgresql_dsn') + def test_timezone(self, session, Article): + timezone = tz.gettz('Europe/Stockholm') + dt = arrow.arrow.get(datetime(2015, 1, 1, 15, 30, 45), timezone) + article = Article(published_at=dt, published_at_dt=dt.datetime) + + session.add(article) + session.commit() + session.expunge_all() + + item = session.query(Article).one() + assert item.published_at.datetime == item.published_at_dt + assert item.published_at.to(timezone) == dt