Add get_bind function

This commit is contained in:
Konsta Vesterinen
2014-05-07 17:17:07 +03:00
parent f7652ea2d2
commit 2e54381e33
6 changed files with 71 additions and 0 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.26.0 (2014-05-xx)
^^^^^^^^^^^^^^^^^^^
- Added get_bind
0.26.0 (2014-05-07)
^^^^^^^^^^^^^^^^^^^

View File

@@ -16,6 +16,12 @@ escape_like
.. autofunction:: escape_like
get_bind
^^^^^^^^
.. autofunction:: get_bind
get_columns
^^^^^^^^^^^

View File

@@ -11,6 +11,7 @@ from .functions import (
dependent_objects,
drop_database,
escape_like,
get_bind,
get_columns,
get_declarative_base,
get_primary_keys,
@@ -82,6 +83,7 @@ __all__ = (
force_instant_defaults,
generates,
generic_relationship,
get_bind,
get_columns,
get_declarative_base,
get_primary_keys,

View File

@@ -13,6 +13,7 @@ from .database import (
)
from .orm import (
dependent_objects,
get_bind,
get_columns,
get_declarative_base,
get_primary_keys,
@@ -34,6 +35,7 @@ __all__ = (
'dependent_objects',
'drop_database',
'escape_like',
'get_bind',
'get_columns',
'get_declarative_base',
'get_primary_keys',

View File

@@ -10,6 +10,7 @@ import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import UnmappedInstanceError
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session
@@ -17,6 +18,39 @@ from sqlalchemy.orm.util import AliasedInsp
from ..query_chain import QueryChain
def get_bind(obj):
"""
Return the bind for given SQLAlchemy Engine / Connection / declarative
model object.
:param obj: SQLAlchemy Engine / Connection / declarative model object
::
from sqlalchemy_utils import get_bind
get_bind(session) # Connection object
get_bind(user)
"""
if hasattr(obj, 'bind'):
conn = obj.bind
else:
try:
conn = object_session(obj).bind
except UnmappedInstanceError:
conn = obj
if not hasattr(conn, 'execute'):
raise TypeError(
'This method accepts only Session, Engine, Connection and '
'declarative model objects.'
)
return conn
def dependent_objects(obj, foreign_keys=None):
"""
Return a QueryChain that iterates through all dependent objects for given

View File

@@ -0,0 +1,21 @@
from pytest import raises
from sqlalchemy_utils import get_bind
from tests import TestCase
class TestGetBind(TestCase):
def test_with_session(self):
assert get_bind(self.session) == self.connection
def test_with_connection(self):
assert get_bind(self.connection) == self.connection
def test_with_model_object(self):
article = self.Article()
self.session.add(article)
assert get_bind(article) == self.connection
def test_with_unknown_type(self):
with raises(TypeError):
get_bind(None)