diff --git a/CHANGES.rst b/CHANGES.rst index 0224fce..da0225e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.30.14 (2015-07-23) +^^^^^^^^^^^^^^^^^^^^ + +- Added cast_if utility function + + 0.30.13 (2015-07-21) ^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/orm_helpers.rst b/docs/orm_helpers.rst index eac570e..aeb7b2a 100644 --- a/docs/orm_helpers.rst +++ b/docs/orm_helpers.rst @@ -4,6 +4,12 @@ ORM helpers .. module:: sqlalchemy_utils.functions +cast_if +------- + +.. autofunction:: cast_if + + escape_like ----------- diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 0322ead..9b73602 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -11,6 +11,7 @@ from .expression_parser import ExpressionParser # noqa from .expressions import Asterisk, row_to_json # noqa from .functions import ( # noqa analyze, + cast_if, create_database, create_mock_engine, database_exists, @@ -92,4 +93,4 @@ from .types import ( # noqa WeekDaysType ) -__version__ = '0.30.13' +__version__ = '0.30.14' diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index a8409df..93fa4a2 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -19,6 +19,7 @@ from .foreign_keys import ( # noqa ) from .mock import create_mock_engine, mock_engine # noqa from .orm import ( # noqa + cast_if, get_bind, get_class_by_table, get_column_key, diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 583d6aa..f067276 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -153,6 +153,57 @@ def get_type(expr): raise TypeError("Couldn't inspect type.") +def cast_if(expression, type_): + """ + Produce a CAST expression but only if given expression is not of given type + already. + + Assume we have a model with two fields id (Integer) and name (String). + + :: + + import sqlalchemy as sa + from sqlalchemy_utils import cast_if + + + cast_if(User.id, sa.Integer) # "user".id + cast_if(User.name, sa.String) # "user".name + cast_if(User.id, sa.String) # CAST("user".id AS TEXT) + + + This function supports scalar values as well. + + :: + + cast_if(1, sa.Integer) # 1 + cast_if('text', sa.String) # 'text' + cast_if(1, sa.String) # CAST(1 AS TEXT) + + + :param expression: + A SQL expression, such as a ColumnElement expression or a Python string + which will be coerced into a bound literal value. + :param type_: + A TypeEngine class or instance indicating the type to which the CAST + should apply. + + .. versionadded: 0.30.14 + """ + try: + expr_type = get_type(expression) + except TypeError: + expr_type = expression + check_type = type_().python_type + else: + check_type = type_ + + return ( + sa.cast(expression, type_) + if not isinstance(expr_type, check_type) + else expression + ) + + def get_column_key(model, column): """ Return the key for given column in given model. diff --git a/tests/functions/test_cast_if.py b/tests/functions/test_cast_if.py new file mode 100644 index 0000000..a03d262 --- /dev/null +++ b/tests/functions/test_cast_if.py @@ -0,0 +1,46 @@ +import pytest +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import cast_if + + +@pytest.fixture(scope='class') +def base(): + return declarative_base() + + +@pytest.fixture(scope='class') +def article_cls(base): + class Article(base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + name_synonym = sa.orm.synonym('name') + + return Article + + +class TestCastIf(object): + def test_column(self, article_cls): + expr = article_cls.__table__.c.name + assert cast_if(expr, sa.String) is expr + + def test_column_property(self, article_cls): + expr = article_cls.name.property + assert cast_if(expr, sa.String) is expr + + def test_instrumented_attribute(self, article_cls): + expr = article_cls.name + assert cast_if(expr, sa.String) is expr + + def test_synonym(self, article_cls): + expr = article_cls.name_synonym + assert cast_if(expr, sa.String) is expr + + def test_scalar_selectable(self, article_cls): + expr = sa.select([article_cls.id]).as_scalar() + assert cast_if(expr, sa.Integer) is expr + + def test_scalar(self): + assert cast_if('something', sa.String) == 'something'