Support for SA 0.9 JSON type

This commit is contained in:
Konsta Vesterinen
2014-01-04 15:21:34 +02:00
parent 4adf291fe9
commit f9b97fd0f5
2 changed files with 30 additions and 12 deletions

View File

@@ -10,16 +10,19 @@ import six
from sqlalchemy.dialects.postgresql.base import ischema_names from sqlalchemy.dialects.postgresql.base import ischema_names
from ..exceptions import ImproperlyConfigured from ..exceptions import ImproperlyConfigured
try:
from sqlalchemy.dialects.postgresql import JSON
has_postgres_json = True
except ImportError:
class PostgresJSONType(sa.types.UserDefinedType):
"""
Text search vector type for postgresql.
"""
def get_col_spec(self):
return 'json'
class PostgresJSONType(sa.types.UserDefinedType): ischema_names['json'] = PostgresJSONType
""" has_postgres_json = False
Text search vector type for postgresql.
"""
def get_col_spec(self):
return 'json'
ischema_names['json'] = PostgresJSONType
class JSONType(sa.types.TypeDecorator): class JSONType(sa.types.TypeDecorator):
@@ -60,16 +63,23 @@ class JSONType(sa.types.TypeDecorator):
def load_dialect_impl(self, dialect): def load_dialect_impl(self, dialect):
if dialect.name == 'postgresql': if dialect.name == 'postgresql':
# Use the native JSON type. # Use the native JSON type.
return dialect.type_descriptor(PostgresJSONType()) if has_postgres_json:
return dialect.type_descriptor(JSON())
else:
return dialect.type_descriptor(PostgresJSONType())
else: else:
return dialect.type_descriptor(self.impl) return dialect.type_descriptor(self.impl)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
if dialect.name == 'postgresql' and has_postgres_json:
return value
if value is not None: if value is not None:
value = six.text_type(json.dumps(value)) value = six.text_type(json.dumps(value))
return value return value
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
if dialect.name == 'postgresql' and has_postgres_json:
return value
if value is not None: if value is not None:
value = json.loads(value) value = json.loads(value)
return value return value

View File

@@ -5,8 +5,7 @@ from sqlalchemy_utils.types import json
from tests import TestCase from tests import TestCase
@mark.skipif('json.json is None') class JSONTestCase(TestCase):
class TestJSONType(TestCase):
def create_models(self): def create_models(self):
class Document(self.Base): class Document(self.Base):
__tablename__ = 'document' __tablename__ = 'document'
@@ -47,3 +46,12 @@ class TestJSONType(TestCase):
document = self.session.query(self.Document).first() document = self.session.query(self.Document).first()
assert document.json == {'something': u'äääööö'} assert document.json == {'something': u'äääööö'}
@mark.skipif('json.json is None')
class TestSqliteJSONType(JSONTestCase):
pass
@mark.skipif('json.json is None')
class TestPostgresJSONType(JSONTestCase):
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'