Add getdotattr
This commit is contained in:
@@ -12,37 +12,39 @@ from .database import (
|
||||
non_indexed_foreign_keys,
|
||||
)
|
||||
from .orm import (
|
||||
primary_keys,
|
||||
table_name,
|
||||
declarative_base,
|
||||
getdotattr,
|
||||
has_changes,
|
||||
identity,
|
||||
naturally_equivalent,
|
||||
remove_property
|
||||
primary_keys,
|
||||
remove_property,
|
||||
table_name,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
create_mock_engine,
|
||||
defer_except,
|
||||
mock_engine,
|
||||
sort_query,
|
||||
render_expression,
|
||||
render_statement,
|
||||
QuerySorterException,
|
||||
database_exists,
|
||||
create_database,
|
||||
create_mock_engine,
|
||||
database_exists,
|
||||
declarative_base,
|
||||
defer_except,
|
||||
drop_database,
|
||||
escape_like,
|
||||
is_auto_assigned_date_column,
|
||||
is_indexed_foreign_key,
|
||||
non_indexed_foreign_keys,
|
||||
remove_property,
|
||||
primary_keys,
|
||||
table_name,
|
||||
declarative_base,
|
||||
getdotattr,
|
||||
has_changes,
|
||||
identity,
|
||||
is_auto_assigned_date_column,
|
||||
is_indexed_foreign_key,
|
||||
mock_engine,
|
||||
naturally_equivalent,
|
||||
non_indexed_foreign_keys,
|
||||
primary_keys,
|
||||
QuerySorterException,
|
||||
remove_property,
|
||||
render_expression,
|
||||
render_statement,
|
||||
sort_query,
|
||||
table_name,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,7 +1,9 @@
|
||||
from functools import partial
|
||||
from toolz import curry, first
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
@@ -218,6 +220,32 @@ def declarative_base(model):
|
||||
return model
|
||||
|
||||
|
||||
def getdotattr(obj_or_class, dot_path):
|
||||
"""
|
||||
Allow dot-notated strings to be passed to `getattr`.
|
||||
|
||||
::
|
||||
|
||||
getdotattr(SubSection, 'section.document')
|
||||
|
||||
getdotattr(subsection, 'section.document')
|
||||
|
||||
|
||||
:param obj_or_class: Any object or class
|
||||
:param dot_path: Attribute path with dot mark as separator
|
||||
"""
|
||||
def get_attr(mixed, attr):
|
||||
if isinstance(mixed, InstrumentedAttribute):
|
||||
return getattr(
|
||||
mixed.property.mapper.class_,
|
||||
attr
|
||||
)
|
||||
else:
|
||||
return getattr(mixed, attr)
|
||||
|
||||
return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class)
|
||||
|
||||
|
||||
def has_changes(obj, attr):
|
||||
"""
|
||||
Simple shortcut function for checking if given attribute of given
|
||||
|
59
tests/functions/test_getdotattr.py
Normal file
59
tests/functions/test_getdotattr.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils.functions import getdotattr
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
|
||||
class TestTwoWayAttributeValueGeneration(TestCase):
|
||||
def create_models(self):
|
||||
class Document(self.Base):
|
||||
__tablename__ = 'document'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
class Section(self.Base):
|
||||
__tablename__ = 'section'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
document_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Document.id)
|
||||
)
|
||||
|
||||
document = sa.orm.relationship(Document)
|
||||
|
||||
class SubSection(self.Base):
|
||||
__tablename__ = 'subsection'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
locale = sa.Column(sa.String(10))
|
||||
|
||||
section_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(Section.id)
|
||||
)
|
||||
|
||||
section = sa.orm.relationship(Section)
|
||||
|
||||
self.Document = Document
|
||||
self.Section = Section
|
||||
self.SubSection = SubSection
|
||||
|
||||
def test_getdotattr_for_objects(self):
|
||||
document = self.Document(name=u'some document')
|
||||
section = self.Section(document=document)
|
||||
subsection = self.SubSection(section=section)
|
||||
|
||||
assert getdotattr(
|
||||
subsection,
|
||||
'section.document.name'
|
||||
) == u'some document'
|
||||
|
||||
def test_getdotattr_for_class_paths(self):
|
||||
assert getdotattr(self.Section, 'document') is self.Section.document
|
||||
assert (
|
||||
getdotattr(self.SubSection, 'section.document') is
|
||||
self.Section.document
|
||||
)
|
||||
assert getdotattr(self.Section, 'document.name') is self.Document.name
|
Reference in New Issue
Block a user