From 17603a2f416e3309434b82353c5dcc4d34a652b3 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 15 Jul 2013 11:57:21 -0700 Subject: [PATCH] Initial UUID type implementation. --- sqlalchemy_utils/__init__.py | 6 ++- sqlalchemy_utils/types/__init__.py | 2 + sqlalchemy_utils/types/uuid.py | 65 ++++++++++++++++++++++++++++++ tests/test_uuid.py | 41 +++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 sqlalchemy_utils/types/uuid.py create mode 100644 tests/test_uuid.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index cf659af..aef926e 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -19,7 +19,8 @@ from .types import ( NumberRangeType, ScalarListType, ScalarListException, - TSVectorType + TSVectorType, + UUIDType, ) @@ -51,5 +52,6 @@ __all__ = ( ProxyDict, ScalarListType, ScalarListException, - TSVectorType + TSVectorType, + UUIDType, ) diff --git a/sqlalchemy_utils/types/__init__.py b/sqlalchemy_utils/types/__init__.py index 5cf5c49..96d9d95 100644 --- a/sqlalchemy_utils/types/__init__.py +++ b/sqlalchemy_utils/types/__init__.py @@ -13,6 +13,7 @@ from .number_range import ( ) from .phone_number import PhoneNumber, PhoneNumberType from .scalar_list import ScalarListException, ScalarListType +from .uuid import UUIDType __all__ = ( @@ -27,6 +28,7 @@ __all__ = ( PhoneNumberType, ScalarListException, ScalarListType, + UUIDType, ) diff --git a/sqlalchemy_utils/types/uuid.py b/sqlalchemy_utils/types/uuid.py new file mode 100644 index 0000000..43cfb4e --- /dev/null +++ b/sqlalchemy_utils/types/uuid.py @@ -0,0 +1,65 @@ +import uuid +from sqlalchemy import types +from sqlalchemy.dialects import postgresql + + +class UUIDType(types.TypeDecorator): + """ + Stores a UUID in the database natively when it can and falls back to + a BINARY(16) or a CHAR(32) when it can't. + """ + + impl = types.BINARY + + python_type = uuid.UUID + + def __init__(self, binary=True): + """ + :param binary: Whether to use a BINARY(16) or CHAR(32) fallback. + """ + self.binary = binary + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + # Use the native UUID type. + return dialect.type_descriptor(postgresql.UUID()) + + else: + # Fallback to either a BINARY or a CHAR. + kind = types.BINARY(16) if self.binary else types.CHAR(32) + return dialect.type_descriptor(kind) + + @staticmethod + def _coerce(value): + if value and not isinstance(value, uuid.UUID): + try: + value = uuid.UUID(value) + + except (TypeError, ValueError): + value = uuid.UUID(bytes=value) + + return value + + def process_bind_param(self, value, dialect): + if value is None: + return value + + if not isinstance(value, uuid.UUID): + value = self._coerce(None, value, None, None) + + if dialect == 'postgresql': + return str(value) + + return value.bytes if self.binary else value.hex + + def process_result_value(self, value, dialect): + if value is None: + return value + + if dialect == 'postgresql': + return uuid.UUID(value) + + return uuid.UUID(bytes=value) if self.binary else uuid.UUID(value) + + def coercion_listener(self, target, value, oldvalue, initiator): + return self._coerce(value) diff --git a/tests/test_uuid.py b/tests/test_uuid.py new file mode 100644 index 0000000..e07465e --- /dev/null +++ b/tests/test_uuid.py @@ -0,0 +1,41 @@ +import sqlalchemy as sa +from tests import TestCase +from sqlalchemy_utils import UUIDType, coercion_listener +import uuid + + +class TestUUIDType(TestCase): + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) + + def test_commit(self): + obj = self.User() + obj.id = uuid.uuid4().hex + + self.session.add(obj) + self.session.commit() + + u = self.session.query(self.User).one() + + assert u.id == obj.id + + def test_coerce(self): + obj = self.User() + obj.id = identifier = uuid.uuid4().hex + + assert isinstance(obj.id, uuid.UUID) + assert obj.id.hex == identifier + + obj.id = identifier = uuid.uuid4().bytes + + assert isinstance(obj.id, uuid.UUID) + assert obj.id.bytes == identifier