From f9b97fd0f572ee82d20bf12de4006bc4a55c682b Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Sat, 4 Jan 2014 15:21:34 +0200 Subject: [PATCH] Support for SA 0.9 JSON type --- sqlalchemy_utils/types/json.py | 30 ++++++++++++++++++++---------- tests/types/test_json.py | 12 ++++++++++-- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/sqlalchemy_utils/types/json.py b/sqlalchemy_utils/types/json.py index 7214d82..58b6397 100644 --- a/sqlalchemy_utils/types/json.py +++ b/sqlalchemy_utils/types/json.py @@ -10,16 +10,19 @@ import six from sqlalchemy.dialects.postgresql.base import ischema_names from ..exceptions import ImproperlyConfigured +try: + from sqlalchemy.dialects.postgresql import JSON + has_postgres_json = True +except ImportError: + class PostgresJSONType(sa.types.UserDefinedType): + """ + Text search vector type for postgresql. + """ + def get_col_spec(self): + return 'json' -class PostgresJSONType(sa.types.UserDefinedType): - """ - Text search vector type for postgresql. - """ - def get_col_spec(self): - return 'json' - - -ischema_names['json'] = PostgresJSONType + ischema_names['json'] = PostgresJSONType + has_postgres_json = False class JSONType(sa.types.TypeDecorator): @@ -60,16 +63,23 @@ class JSONType(sa.types.TypeDecorator): def load_dialect_impl(self, dialect): if dialect.name == 'postgresql': # Use the native JSON type. - return dialect.type_descriptor(PostgresJSONType()) + if has_postgres_json: + return dialect.type_descriptor(JSON()) + else: + return dialect.type_descriptor(PostgresJSONType()) else: return dialect.type_descriptor(self.impl) def process_bind_param(self, value, dialect): + if dialect.name == 'postgresql' and has_postgres_json: + return value if value is not None: value = six.text_type(json.dumps(value)) return value def process_result_value(self, value, dialect): + if dialect.name == 'postgresql' and has_postgres_json: + return value if value is not None: value = json.loads(value) return value diff --git a/tests/types/test_json.py b/tests/types/test_json.py index cf2ffb1..5115883 100644 --- a/tests/types/test_json.py +++ b/tests/types/test_json.py @@ -5,8 +5,7 @@ from sqlalchemy_utils.types import json from tests import TestCase -@mark.skipif('json.json is None') -class TestJSONType(TestCase): +class JSONTestCase(TestCase): def create_models(self): class Document(self.Base): __tablename__ = 'document' @@ -47,3 +46,12 @@ class TestJSONType(TestCase): document = self.session.query(self.Document).first() assert document.json == {'something': u'äääööö'} + + +@mark.skipif('json.json is None') +class TestSqliteJSONType(JSONTestCase): + pass + +@mark.skipif('json.json is None') +class TestPostgresJSONType(JSONTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'