From 2e54381e3336fe2d225e0f59af53f9ccd894dedc Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 7 May 2014 17:17:07 +0300 Subject: [PATCH] Add get_bind function --- CHANGES.rst | 6 +++++ docs/model_helpers.rst | 6 +++++ sqlalchemy_utils/__init__.py | 2 ++ sqlalchemy_utils/functions/__init__.py | 2 ++ sqlalchemy_utils/functions/orm.py | 34 ++++++++++++++++++++++++++ tests/functions/test_get_bind.py | 21 ++++++++++++++++ 6 files changed, 71 insertions(+) create mode 100644 tests/functions/test_get_bind.py diff --git a/CHANGES.rst b/CHANGES.rst index 53ac8c3..c2128ed 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.26.0 (2014-05-xx) +^^^^^^^^^^^^^^^^^^^ + +- Added get_bind + + 0.26.0 (2014-05-07) ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/model_helpers.rst b/docs/model_helpers.rst index 253c1bc..0bfee53 100644 --- a/docs/model_helpers.rst +++ b/docs/model_helpers.rst @@ -16,6 +16,12 @@ escape_like .. autofunction:: escape_like +get_bind +^^^^^^^^ + +.. autofunction:: get_bind + + get_columns ^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 732f797..d5ed830 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -11,6 +11,7 @@ from .functions import ( dependent_objects, drop_database, escape_like, + get_bind, get_columns, get_declarative_base, get_primary_keys, @@ -82,6 +83,7 @@ __all__ = ( force_instant_defaults, generates, generic_relationship, + get_bind, get_columns, get_declarative_base, get_primary_keys, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 5145e1b..7b56fb3 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -13,6 +13,7 @@ from .database import ( ) from .orm import ( dependent_objects, + get_bind, get_columns, get_declarative_base, get_primary_keys, @@ -34,6 +35,7 @@ __all__ = ( 'dependent_objects', 'drop_database', 'escape_like', + 'get_bind', 'get_columns', 'get_declarative_base', 'get_primary_keys', diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index ea2993f..1aa0008 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -10,6 +10,7 @@ import sqlalchemy as sa from sqlalchemy import inspect from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.exc import UnmappedInstanceError from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session @@ -17,6 +18,39 @@ from sqlalchemy.orm.util import AliasedInsp from ..query_chain import QueryChain +def get_bind(obj): + """ + Return the bind for given SQLAlchemy Engine / Connection / declarative + model object. + + :param obj: SQLAlchemy Engine / Connection / declarative model object + + :: + + from sqlalchemy_utils import get_bind + + + get_bind(session) # Connection object + + get_bind(user) + + """ + if hasattr(obj, 'bind'): + conn = obj.bind + else: + try: + conn = object_session(obj).bind + except UnmappedInstanceError: + conn = obj + + if not hasattr(conn, 'execute'): + raise TypeError( + 'This method accepts only Session, Engine, Connection and ' + 'declarative model objects.' + ) + return conn + + def dependent_objects(obj, foreign_keys=None): """ Return a QueryChain that iterates through all dependent objects for given diff --git a/tests/functions/test_get_bind.py b/tests/functions/test_get_bind.py new file mode 100644 index 0000000..c10a5f0 --- /dev/null +++ b/tests/functions/test_get_bind.py @@ -0,0 +1,21 @@ +from pytest import raises + +from sqlalchemy_utils import get_bind +from tests import TestCase + + +class TestGetBind(TestCase): + def test_with_session(self): + assert get_bind(self.session) == self.connection + + def test_with_connection(self): + assert get_bind(self.connection) == self.connection + + def test_with_model_object(self): + article = self.Article() + self.session.add(article) + assert get_bind(article) == self.connection + + def test_with_unknown_type(self): + with raises(TypeError): + get_bind(None)