Add getdotattr

This commit is contained in:
Konsta Vesterinen
2014-02-18 11:36:19 +02:00
parent 1b344ff354
commit 09c2d7ddc7
3 changed files with 107 additions and 18 deletions

View File

@@ -12,37 +12,39 @@ from .database import (
non_indexed_foreign_keys,
)
from .orm import (
primary_keys,
table_name,
declarative_base,
getdotattr,
has_changes,
identity,
naturally_equivalent,
remove_property
primary_keys,
remove_property,
table_name,
)
__all__ = (
create_mock_engine,
defer_except,
mock_engine,
sort_query,
render_expression,
render_statement,
QuerySorterException,
database_exists,
create_database,
create_mock_engine,
database_exists,
declarative_base,
defer_except,
drop_database,
escape_like,
is_auto_assigned_date_column,
is_indexed_foreign_key,
non_indexed_foreign_keys,
remove_property,
primary_keys,
table_name,
declarative_base,
getdotattr,
has_changes,
identity,
is_auto_assigned_date_column,
is_indexed_foreign_key,
mock_engine,
naturally_equivalent,
non_indexed_foreign_keys,
primary_keys,
QuerySorterException,
remove_property,
render_expression,
render_statement,
sort_query,
table_name,
)

View File

@@ -1,7 +1,9 @@
from functools import partial
from toolz import curry, first
import six
import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.util import AliasedInsp
@@ -218,6 +220,32 @@ def declarative_base(model):
return model
def getdotattr(obj_or_class, dot_path):
"""
Allow dot-notated strings to be passed to `getattr`.
::
getdotattr(SubSection, 'section.document')
getdotattr(subsection, 'section.document')
:param obj_or_class: Any object or class
:param dot_path: Attribute path with dot mark as separator
"""
def get_attr(mixed, attr):
if isinstance(mixed, InstrumentedAttribute):
return getattr(
mixed.property.mapper.class_,
attr
)
else:
return getattr(mixed, attr)
return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class)
def has_changes(obj, attr):
"""
Simple shortcut function for checking if given attribute of given

View File

@@ -0,0 +1,59 @@
import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr
from tests import TestCase
class TestTwoWayAttributeValueGeneration(TestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
document = sa.orm.relationship(Document)
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section)
self.Document = Document
self.Section = Section
self.SubSection = SubSection
def test_getdotattr_for_objects(self):
document = self.Document(name=u'some document')
section = self.Section(document=document)
subsection = self.SubSection(section=section)
assert getdotattr(
subsection,
'section.document.name'
) == u'some document'
def test_getdotattr_for_class_paths(self):
assert getdotattr(self.Section, 'document') is self.Section.document
assert (
getdotattr(self.SubSection, 'section.document') is
self.Section.document
)
assert getdotattr(self.Section, 'document.name') is self.Document.name